apply new black formatting to all files

This commit is contained in:
Desh Raj 2022-11-16 13:06:43 -05:00
parent aa7bae1ecd
commit d110b04ad3
440 changed files with 6789 additions and 14532 deletions

View File

@ -45,17 +45,18 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
# See https://github.com/psf/black/issues/2964 # Click issue fixed in https://github.com/psf/black/pull/2966
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8 - name: Run flake8
shell: bash shell: bash
working-directory: ${{github.workspace}} working-directory: ${{github.workspace}}
run: | run: |
# stop the build if there are Python syntax errors or undefined names # stop the build if there are Python syntax errors or undefined names
flake8 . --count --show-source --statistics flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 . # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
- name: Run black - name: Run black
shell: bash shell: bash

View File

@ -1,26 +1,38 @@
repos: repos:
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 21.6b0 rev: 22.3.0
hooks: hooks:
- id: black - id: black
args: [--line-length=80] args: ["--line-length=88"]
additional_dependencies: ['click==8.0.1'] additional_dependencies: ['click==8.0.1']
exclude: icefall\/__init__\.py exclude: icefall\/__init__\.py
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 3.9.2 rev: 5.0.4
hooks: hooks:
- id: flake8 - id: flake8
args: [--max-line-length=80] args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
# What are we ignoring here?
# E203: whitespace before ':'
# E266: too many leading '#' for block comment
# E501: line too long
# F401: module imported but unused
# E402: module level import not at top of file
# F403: 'from module import *' used; unable to detect undefined names
# F841: local variable is assigned to but never used
# W503: line break before binary operator
# In addition, the default ignore list is:
# E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.9.2 rev: 5.10.1
hooks: hooks:
- id: isort - id: isort
args: [--profile=black, --line-length=80] args: ["--profile=black"]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1 rev: v4.2.0
hooks: hooks:
- id: check-executables-have-shebangs - id: check-executables-have-shebangs
- id: end-of-file-fixer - id: end-of-file-fixer

View File

@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall WORKDIR /workspace/icefall

View File

@ -19,4 +19,3 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
tdnn_lstm_ctc tdnn_lstm_ctc
conformer_ctc conformer_ctc
stateless_transducer stateless_transducer

View File

@ -6,4 +6,3 @@ TIMIT
tdnn_ligru_ctc tdnn_ligru_ctc
tdnn_lstm_ctc tdnn_lstm_ctc

View File

@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -116,9 +114,7 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon( def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst( fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -50,15 +50,15 @@ def get_parser():
"-n", "-n",
default=1, default=1,
type=int, type=int,
help="number of characters to split, i.e., \ help=(
aabb -> a a b b with -n 1 and aa bb with -n 2", "number of characters to split, i.e., aabb -> a a b"
" b with -n 1 and aa bb with -n 2"
),
) )
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument( parser.add_argument("--space", default="<space>", type=str, help="space symbol")
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -66,9 +66,7 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument( parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@ -108,8 +106,7 @@ def token2id(
if token_type == "lazy_pinyin": if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list) text = lazy_pinyin(chars_list)
sub_ids = [ sub_ids = [
token_table[txt] if txt in token_table else oov_id token_table[txt] if txt in token_table else oov_id for txt in text
for txt in text
] ]
ids.append(sub_ids) ids.append(sub_ids)
else: # token_type = "pinyin" else: # token_type = "pinyin"
@ -135,9 +132,7 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")( f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -113,4 +113,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./local/prepare_char.py ./local/prepare_char.py
fi fi
fi fi

View File

@ -81,10 +81,12 @@ class Aidatatang_200zhAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc."
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -96,75 +98,91 @@ class Aidatatang_200zhAsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help=(
"single batch. You can reduce it if it causes CUDA OOM.", "Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, the batches will come from buckets of " help=(
"similar duration (saves padding frames).", "When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=300,
help="The number of buckets for the DynamicBucketingSampler" help=(
"(you might want to increase it for larger datasets).", "The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, utterances (cuts) will be concatenated " help=(
"to minimize the amount of padding.", "When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help="Determines the maximum duration of a concatenated cut " help=(
"relative to the duration of the longest cut in a batch.", "Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help="The amount of padding (in seconds) inserted between " help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.", "noise augmentation is used."
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, use on-the-fly cut mixing and feature " help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available.", "if available."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled (=default), the examples will be " help=(
"shuffled for each epoch.", "When enabled (=default), the examples will be shuffled for each epoch."
),
) )
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, each batch will have the " help=(
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.", "were used to construct it."
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that collect the batches.",
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -178,18 +196,22 @@ class Aidatatang_200zhAsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help="Used only when --enable-spec-aug is True. " help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp.", "A value less than 1 means to disable time warp."
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, select noise from MUSAN and mix it" help=(
"with training dataset. ", "When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
) )
def train_dataloaders( def train_dataloaders(
@ -205,24 +227,20 @@ class Aidatatang_200zhAsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest( cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " "Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -237,9 +255,7 @@ class Aidatatang_200zhAsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info( logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -282,9 +298,7 @@ class Aidatatang_200zhAsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -340,9 +354,7 @@ class Aidatatang_200zhAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -69,11 +69,7 @@ from beam_search import (
) )
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -92,25 +88,30 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--batch", "--batch",
type=int, type=int,
default=None, default=None,
help="It specifies the batch checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -192,8 +193,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -249,9 +249,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -266,10 +264,7 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -315,11 +310,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -390,9 +381,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -425,8 +414,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)

View File

@ -62,17 +62,20 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -103,8 +106,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
return parser return parser
@ -173,9 +175,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -85,9 +85,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -112,10 +114,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -162,8 +166,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -193,10 +196,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -257,9 +259,7 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,10 +284,7 @@ def main():
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -339,9 +336,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -81,9 +81,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -187,42 +185,45 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help="The prune range for rnnt loss, it means how many symbols(context)" help=(
"we are using to compute the loss", "The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help="The scale to smooth the loss with lm " help=(
"(output of prediction network) part.", "The scale to smooth the loss with lm (output of prediction network) part."
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)part.",
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help="To get pruning ranges, we will calculate a simple version" help=(
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss."
),
) )
parser.add_argument( parser.add_argument(
@ -542,22 +543,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -711,9 +705,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention( self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm( self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm( self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout( src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
def __init__( def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vector and `j` means the # Suppose `i` means to the position of query vector and `j` means the
@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
q, k, v = nn.functional.linear( q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
query, in_proj_weight, in_proj_bias 3, dim=-1
).chunk(3, dim=-1) )
elif torch.equal(key, value): elif torch.equal(key, value):
# encoder-decoder attention # encoder-decoder attention
@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError( raise RuntimeError("The size of the 2D attn_mask is not correct.")
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [ if list(attn_mask.size()) != [
bsz * num_heads, bsz * num_heads,
query.size(0), query.size(0),
key.size(0), key.size(0),
]: ]:
raise RuntimeError( raise RuntimeError("The size of the 3D attn_mask is not correct.")
"The size of the 3D attn_mask is not correct."
)
else: else:
raise RuntimeError( raise RuntimeError(
"attn_mask's dimension {} is not supported".format( "attn_mask's dimension {} is not supported".format(attn_mask.dim())
attn_mask.dim()
)
) )
# attn_mask's dim is 3 now. # attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool # convert ByteTensor key_padding_mask to bool
if ( if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn( warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
" instead."
) )
key_padding_mask = key_padding_mask.to(torch.bool) key_padding_mask = key_padding_mask.to(torch.bool)
@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul( matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [ assert list(attn_output_weights.size()) == [
bsz * num_heads, bsz * num_heads,
@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = ( attn_output = (
attn_output.transpose(0, 1) attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
) )
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights: if need_weights:
# average attention weights over heads # average attention weights over heads
@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module):
""" """
def __init__( def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding # kernerl_size should be a odd number for 'SAME' padding

View File

@ -58,16 +58,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=49, default=49,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=20, default=20,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -401,9 +404,7 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -431,9 +432,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -441,9 +440,7 @@ def save_results(
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log: if enable_log:
logging.info( logging.info("Wrote detailed error stats to {}".format(errs_filename))
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -562,9 +559,7 @@ def main():
eos_id=eos_id, eos_id=eos_id,
) )
save_results( save_results(params=params, test_set_name=test_set, results_dict=results_dict)
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!") logging.info("Done!")

View File

@ -40,17 +40,20 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=84, default=84,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=25, default=25,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -157,9 +160,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -46,27 +46,29 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
"--tokens-file", "--tokens-file",
type=str, type=str,
help="Path to tokens.txt" "Used only when method is ctc-decoding", help="Path to tokens.txtUsed only when method is ctc-decoding",
) )
parser.add_argument( parser.add_argument(
"--words-file", "--words-file",
type=str, type=str,
help="Path to words.txt" "Used when method is NOT ctc-decoding", help="Path to words.txtUsed when method is NOT ctc-decoding",
) )
parser.add_argument( parser.add_argument(
"--HLG", "--HLG",
type=str, type=str,
help="Path to HLG.pt." "Used when method is NOT ctc-decoding", help="Path to HLG.pt.Used when method is NOT ctc-decoding",
) )
parser.add_argument( parser.add_argument(
@ -163,10 +165,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
return parser return parser
@ -210,10 +214,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -274,9 +277,7 @@ def main():
logging.info("Decoding started") logging.info("Decoding started")
features = fbank(waves) features = fbank(waves)
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding # Note: We don't use key padding mask for attention during decoding
with torch.no_grad(): with torch.no_grad():
@ -371,9 +372,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7 assert idim >= 7
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d( nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
) )
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
) )
) )
layers.append( layers.append(
torch.nn.MaxPool2d( torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
) )
cur_channels = block_dim cur_channels = block_dim
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.out = nn.Linear( self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x. """Subsample x.

View File

@ -16,9 +16,8 @@
# limitations under the License. # limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch import torch
from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling(): def test_conv2d_subsampling():

View File

@ -382,9 +382,7 @@ def compute_loss(
# #
# See https://github.com/k2-fsa/icefall/issues/97 # See https://github.com/k2-fsa/icefall/issues/97
# for more details # for more details
unsorted_token_ids = graph_compiler.texts_to_ids( unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
supervisions["text"]
)
att_loss = mmodel.decoder_forward( att_loss = mmodel.decoder_forward(
encoder_memory, encoder_memory,
memory_mask, memory_mask,
@ -520,9 +518,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -630,9 +626,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -149,9 +149,7 @@ class Transformer(nn.Module):
norm=decoder_norm, norm=decoder_norm,
) )
self.decoder_output_layer = torch.nn.Linear( self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss() self.decoder_criterion = LabelSmoothingLoss()
else: else:
@ -183,9 +181,7 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x) x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder( encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x, supervision
)
x = self.ctc_output(encoder_memory) x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask return x, encoder_memory, memory_key_padding_mask
@ -266,23 +262,17 @@ class Transformer(nn.Module):
""" """
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence( ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence( ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device) ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device) ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -343,23 +333,17 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence( ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence( ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu": elif activation == "gelu":
return nn.functional.gelu return nn.functional.gelu
raise RuntimeError( raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -836,9 +818,7 @@ def encoder_padding_mask(
1, 1,
).to(torch.int32) ).to(torch.int32)
lengths = [ lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)): for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples # Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item() sequence_idx = supervision_segments[idx, 0].item()
@ -859,9 +839,7 @@ def encoder_padding_mask(
return mask return mask
def decoder_padding_mask( def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input. """Generate a length mask for input.
The masked position are filled with True, The masked position are filled with True,

View File

@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention( self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm( self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm( self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout( src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
def __init__( def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vector and `j` means the # Suppose `i` means to the position of query vector and `j` means the
@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
q, k, v = nn.functional.linear( q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
query, in_proj_weight, in_proj_bias 3, dim=-1
).chunk(3, dim=-1) )
elif torch.equal(key, value): elif torch.equal(key, value):
# encoder-decoder attention # encoder-decoder attention
@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError( raise RuntimeError("The size of the 2D attn_mask is not correct.")
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [ if list(attn_mask.size()) != [
bsz * num_heads, bsz * num_heads,
query.size(0), query.size(0),
key.size(0), key.size(0),
]: ]:
raise RuntimeError( raise RuntimeError("The size of the 3D attn_mask is not correct.")
"The size of the 3D attn_mask is not correct."
)
else: else:
raise RuntimeError( raise RuntimeError(
"attn_mask's dimension {} is not supported".format( "attn_mask's dimension {} is not supported".format(attn_mask.dim())
attn_mask.dim()
)
) )
# attn_mask's dim is 3 now. # attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool # convert ByteTensor key_padding_mask to bool
if ( if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn( warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
" instead."
) )
key_padding_mask = key_padding_mask.to(torch.bool) key_padding_mask = key_padding_mask.to(torch.bool)
@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul( matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [ assert list(attn_output_weights.size()) == [
bsz * num_heads, bsz * num_heads,
@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = ( attn_output = (
attn_output.transpose(0, 1) attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
) )
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights: if need_weights:
# average attention weights over heads # average attention weights over heads
@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module):
""" """
def __init__( def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding # kernerl_size should be a odd number for 'SAME' padding

View File

@ -59,16 +59,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=49, default=49,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=20, default=20,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -413,9 +416,7 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -443,9 +444,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -453,9 +452,7 @@ def save_results(
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log: if enable_log:
logging.info( logging.info("Wrote detailed error stats to {}".format(errs_filename))
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@ -550,9 +547,7 @@ def main():
if params.export: if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save( torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return return
model.to(device) model.to(device)
@ -581,9 +576,7 @@ def main():
eos_id=eos_id, eos_id=eos_id,
) )
save_results( save_results(params=params, test_set_name=test_set, results_dict=results_dict)
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!") logging.info("Done!")

View File

@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7 assert idim >= 7
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
nn.Conv2d( nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(), nn.ReLU(),
) )
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
) )
) )
layers.append( layers.append(
torch.nn.MaxPool2d( torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
) )
cur_channels = block_dim cur_channels = block_dim
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.out = nn.Linear( self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x. """Subsample x.

View File

@ -511,9 +511,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -625,9 +623,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -149,9 +149,7 @@ class Transformer(nn.Module):
norm=decoder_norm, norm=decoder_norm,
) )
self.decoder_output_layer = torch.nn.Linear( self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss() self.decoder_criterion = LabelSmoothingLoss()
else: else:
@ -183,9 +181,7 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x) x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder( encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x, supervision
)
x = self.ctc_output(encoder_memory) x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask return x, encoder_memory, memory_key_padding_mask
@ -266,23 +262,17 @@ class Transformer(nn.Module):
""" """
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence( ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence( ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device) ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device) ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -343,23 +333,17 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence( ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence( ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask # TODO: Use length information to create the decoder padding mask
@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu": elif activation == "gelu":
return nn.functional.gelu return nn.functional.gelu
raise RuntimeError( raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -836,9 +818,7 @@ def encoder_padding_mask(
1, 1,
).to(torch.int32) ).to(torch.int32)
lengths = [ lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)): for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples # Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item() sequence_idx = supervision_segments[idx, 0].item()
@ -859,9 +839,7 @@ def encoder_padding_mask(
return mask return mask
def decoder_padding_mask( def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input. """Generate a length mask for input.
The masked position are filled with True, The masked position are filled with True,

View File

@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -116,9 +114,7 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -111,9 +109,7 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon( def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst( fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -76,11 +76,7 @@ from beam_search import (
) )
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -118,9 +114,11 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
@ -188,8 +186,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -249,9 +246,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -263,10 +258,7 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -310,11 +302,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -387,9 +375,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -415,9 +401,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -428,8 +412,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -473,9 +456,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -504,8 +485,7 @@ def main():
] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(

View File

@ -50,11 +50,7 @@ from pathlib import Path
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import str2bool from icefall.utils import str2bool
@ -87,9 +83,11 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
@ -120,8 +118,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -157,8 +154,7 @@ def main():
] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -191,9 +187,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward) model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = ( filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename)) model.save(str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
else: else:
@ -201,17 +195,14 @@ def main():
# Save it using a format so that it can be loaded # Save it using a format so that it can be loaded
# by :func:`load_checkpoint` # by :func:`load_checkpoint`
filename = ( filename = (
params.exp_dir params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
) )
torch.save({"model": model.state_dict()}, str(filename)) torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -87,9 +87,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -115,10 +117,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -165,15 +169,16 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help="Maximum number of symbols per frame. " help=(
"Use only when --method is greedy_search", "Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -196,10 +201,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -256,13 +260,9 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -310,9 +310,7 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.method}")
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -329,9 +327,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -49,7 +49,6 @@ import optim
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from asr_datamodule import AishellAsrDataModule
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
@ -75,9 +74,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -203,8 +200,7 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need " help="The initial learning rate. This value should not need to be changed.",
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -227,42 +223,45 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help="The prune range for rnnt loss, it means how many symbols(context)" help=(
"we are using to compute the loss", "The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help="The scale to smooth the loss with lm " help=(
"(output of prediction network) part.", "The scale to smooth the loss with lm (output of prediction network) part."
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)part.",
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help="To get pruning ranges, we will calculate a simple version" help=(
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss."
),
) )
parser.add_argument( parser.add_argument(
@ -561,11 +560,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -593,23 +588,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -725,9 +713,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch( display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
batch, params=params, graph_compiler=graph_compiler
)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
@ -1029,9 +1015,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch( display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
batch, params=params, graph_compiler=graph_compiler
)
raise raise

View File

@ -121,20 +121,24 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help="Whether to load averaged model. Currently it only supports " help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. "
),
) )
parser.add_argument( parser.add_argument(
@ -202,8 +206,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -263,9 +266,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -277,10 +278,7 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -324,11 +322,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -401,9 +395,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -429,9 +421,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -442,8 +432,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -488,9 +477,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -518,13 +505,12 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -551,13 +537,12 @@ def main():
) )
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -586,7 +571,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
f"Calculating the averaged model over epoch range from " "Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)

View File

@ -88,20 +88,24 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=True, default=True,
help="Whether to load averaged model. Currently it only supports " help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. "
),
) )
parser.add_argument( parser.add_argument(
@ -132,8 +136,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -166,13 +169,12 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -195,13 +197,12 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -229,7 +230,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
f"Calculating the averaged model over epoch range from " "Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -252,9 +253,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward) model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = ( filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename)) model.save(str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
else: else:
@ -262,17 +261,14 @@ def main():
# Save it using a format so that it can be loaded # Save it using a format so that it can be loaded
# by :func:`load_checkpoint` # by :func:`load_checkpoint`
filename = ( filename = (
params.exp_dir params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
) )
torch.save({"model": model.state_dict()}, str(filename)) torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}") logging.info(f"Saved to {filename}")
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -84,9 +84,7 @@ class Transducer(nn.Module):
self.decoder_datatang = decoder_datatang self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang self.joiner_datatang = joiner_datatang
self.simple_am_proj = ScaledLinear( self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None: if decoder_datatang is not None:
@ -179,9 +177,7 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64) y_padded = y_padded.to(torch.int64)
boundary = torch.zeros( boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens boundary[:, 3] = encoder_out_lens

View File

@ -87,9 +87,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -115,10 +117,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -165,15 +169,16 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help="Maximum number of symbols per frame. " help=(
"Use only when --method is greedy_search", "Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -196,10 +201,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -257,13 +261,9 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -311,9 +311,7 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.method}")
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -330,9 +328,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -96,9 +96,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -224,8 +222,7 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need " help="The initial learning rate. This value should not need to be changed.",
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -248,42 +245,45 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=1, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help="The prune range for rnnt loss, it means how many symbols(context)" help=(
"we are using to compute the loss", "The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help="The scale to smooth the loss with lm " help=(
"(output of prediction network) part.", "The scale to smooth the loss with lm (output of prediction network) part."
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)part.",
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help="To get pruning ranges, we will calculate a simple version" help=(
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss."
),
) )
parser.add_argument( parser.add_argument(
@ -635,11 +635,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -670,23 +666,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -824,9 +813,7 @@ def train_one_epoch(
) )
# summary stats # summary stats
if datatang_train_dl is not None: if datatang_train_dl is not None:
tot_loss = ( tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
tot_loss * (1 - 1 / params.reset_interval)
) + loss_info
if aishell: if aishell:
aishell_tot_loss = ( aishell_tot_loss = (
@ -847,9 +834,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch( display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
batch, params=params, graph_compiler=graph_compiler
)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
@ -892,9 +877,7 @@ def train_one_epoch(
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]
if datatang_train_dl is not None: if datatang_train_dl is not None:
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
tot_loss_str = ( tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
)
else: else:
tot_loss_str = "" tot_loss_str = ""
datatang_str = "" datatang_str = ""
@ -1076,9 +1059,7 @@ def run(rank, world_size, args):
train_cuts = filter_short_and_long_utterances(train_cuts) train_cuts = filter_short_and_long_utterances(train_cuts)
if args.enable_musan: if args.enable_musan:
cuts_musan = load_manifest( cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
else: else:
cuts_musan = None cuts_musan = None
@ -1093,9 +1074,7 @@ def run(rank, world_size, args):
if params.datatang_prob > 0: if params.datatang_prob > 0:
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts() train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances( train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
train_datatang_cuts
)
train_datatang_cuts = train_datatang_cuts.repeat(times=None) train_datatang_cuts = train_datatang_cuts.repeat(times=None)
datatang_train_dl = asr_datamodule.train_dataloaders( datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts, train_datatang_cuts,
@ -1249,9 +1228,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch( display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
batch, params=params, graph_compiler=graph_compiler
)
raise raise

View File

@ -64,10 +64,12 @@ class AishellAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc."
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -79,59 +81,74 @@ class AishellAsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help=(
"single batch. You can reduce it if it causes CUDA OOM.", "Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, the batches will come from buckets of " help=(
"similar duration (saves padding frames).", "When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=30, default=30,
help="The number of buckets for the DynamicBucketingSampler" help=(
"(you might want to increase it for larger datasets).", "The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, utterances (cuts) will be concatenated " help=(
"to minimize the amount of padding.", "When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help="Determines the maximum duration of a concatenated cut " help=(
"relative to the duration of the longest cut in a batch.", "Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help="The amount of padding (in seconds) inserted between " help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.", "noise augmentation is used."
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, use on-the-fly cut mixing and feature " help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available.", "if available."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled (=default), the examples will be " help=(
"shuffled for each epoch.", "When enabled (=default), the examples will be shuffled for each epoch."
),
) )
group.add_argument( group.add_argument(
"--drop-last", "--drop-last",
@ -143,17 +160,18 @@ class AishellAsrDataModule:
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, each batch will have the " help=(
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.", "were used to construct it."
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that collect the batches.",
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -167,40 +185,40 @@ class AishellAsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help="Used only when --enable-spec-aug is True. " help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp.", "A value less than 1 means to disable time warp."
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, select noise from MUSAN and mix it" help=(
"with training dataset. ", "When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
) )
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest( cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " "Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -215,9 +233,7 @@ class AishellAsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info( logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -260,9 +276,7 @@ class AishellAsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -308,9 +322,7 @@ class AishellAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:
@ -366,13 +378,9 @@ class AishellAsrDataModule:
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get dev cuts")
return load_manifest_lazy( return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_cuts(self) -> List[CutSet]: def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts") logging.info("About to get test cuts")
return load_manifest_lazy( return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
)

View File

@ -49,16 +49,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=19, default=19,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=5, default=5,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
"--method", "--method",
@ -265,9 +268,7 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -289,9 +290,7 @@ def save_results(
# We compute CER for aishell dataset. # We compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
test_set_wers[key] = wer test_set_wers[key] = wer
@ -335,9 +334,7 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict( HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device) HLG = HLG.to(device)
assert HLG.requires_grad is False assert HLG.requires_grad is False
@ -362,9 +359,7 @@ def main():
if params.export: if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save( torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
model.to(device) model.to(device)
model.eval() model.eval()
@ -392,9 +387,7 @@ def main():
lexicon=lexicon, lexicon=lexicon,
) )
save_results( save_results(params=params, test_set_name=test_set, results_dict=results_dict)
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!") logging.info("Done!")

View File

@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
nn.BatchNorm1d(num_features=500, affine=False), nn.BatchNorm1d(num_features=500, affine=False),
) )
self.lstms = nn.ModuleList( self.lstms = nn.ModuleList(
[ [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
for _ in range(5)
]
) )
self.lstm_bnorms = nn.ModuleList( self.lstm_bnorms = nn.ModuleList(
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]

View File

@ -41,9 +41,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -53,9 +55,7 @@ def get_parser():
help="Path to words.txt", help="Path to words.txt",
) )
parser.add_argument( parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument( parser.add_argument(
"--method", "--method",
@ -71,10 +71,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
return parser return parser
@ -112,10 +114,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -173,9 +174,7 @@ def main():
logging.info("Decoding started") logging.info("Decoding started")
features = fbank(waves) features = fbank(waves)
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
features = features.permute(0, 2, 1) # now features is [N, C, T] features = features.permute(0, 2, 1) # now features is [N, C, T]
with torch.no_grad(): with torch.no_grad():
@ -219,9 +218,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser(): def get_parser():

View File

@ -47,9 +47,9 @@ def greedy_search(
device = model.device device = model.device
decoder_input = torch.tensor( decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
[blank_id] * context_size, device=device 1, context_size
).reshape(1, context_size) )
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -81,9 +81,9 @@ def greedy_search(
y = logits.argmax().item() y = logits.argmax().item()
if y != blank_id: if y != blank_id:
hyp.append(y) hyp.append(y)
decoder_input = torch.tensor( decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
[hyp[-context_size:]], device=device 1, context_size
).reshape(1, context_size) )
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
@ -157,9 +157,7 @@ class HypothesisList(object):
""" """
if length_norm: if length_norm:
return max( return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else: else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob) return max(self._data.values(), key=lambda hyp: hyp.log_prob)
@ -246,9 +244,9 @@ def beam_search(
device = model.device device = model.device
decoder_input = torch.tensor( decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
[blank_id] * context_size, device=device 1, context_size
).reshape(1, context_size) )
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True, normalize_before: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention( self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm( self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5 self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm( self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src residual = src
if self.normalize_before: if self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout( src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
self.feed_forward_macaron(src)
)
if not self.normalize_before: if not self.normalize_before:
src = self.norm_ff_macaron(src) src = self.norm_ff_macaron(src)
@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
""" """
def __init__( def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# Suppose `i` means to the position of query vector and `j` means the # Suppose `i` means to the position of query vector and `j` means the
@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value): if torch.equal(query, key) and torch.equal(key, value):
# self-attention # self-attention
q, k, v = nn.functional.linear( q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
query, in_proj_weight, in_proj_bias 3, dim=-1
).chunk(3, dim=-1) )
elif torch.equal(key, value): elif torch.equal(key, value):
# encoder-decoder attention # encoder-decoder attention
@ -701,33 +689,25 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError( raise RuntimeError("The size of the 2D attn_mask is not correct.")
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [ if list(attn_mask.size()) != [
bsz * num_heads, bsz * num_heads,
query.size(0), query.size(0),
key.size(0), key.size(0),
]: ]:
raise RuntimeError( raise RuntimeError("The size of the 3D attn_mask is not correct.")
"The size of the 3D attn_mask is not correct."
)
else: else:
raise RuntimeError( raise RuntimeError(
"attn_mask's dimension {} is not supported".format( "attn_mask's dimension {} is not supported".format(attn_mask.dim())
attn_mask.dim()
)
) )
# attn_mask's dim is 3 now. # attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool # convert ByteTensor key_padding_mask to bool
if ( if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn( warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
" instead."
) )
key_padding_mask = key_padding_mask.to(torch.bool) key_padding_mask = key_padding_mask.to(torch.bool)
@ -764,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul( matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
@ -778,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [ assert list(attn_output_weights.size()) == [
bsz * num_heads, bsz * num_heads,
@ -814,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = ( attn_output = (
attn_output.transpose(0, 1) attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
) )
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights: if need_weights:
# average attention weights over heads # average attention weights over heads
@ -843,9 +815,7 @@ class ConvolutionModule(nn.Module):
""" """
def __init__( def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding # kernerl_size should be a odd number for 'SAME' padding

View File

@ -52,16 +52,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=30, default=30,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -99,8 +102,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -227,9 +229,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@ -248,9 +248,7 @@ def decode_one_batch(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp]) hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
@ -319,9 +317,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -346,9 +342,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -359,8 +353,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -430,9 +423,7 @@ def main():
if params.export: if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save( torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return return
model.to(device) model.to(device)

View File

@ -86,9 +86,7 @@ class Decoder(nn.Module):
if self.context_size > 1: if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: if need_pad is True:
embedding_out = F.pad( embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
embedding_out, pad=(self.context_size - 1, 0)
)
else: else:
# During inference time, there is no need to do extra padding # During inference time, there is no need to do extra padding
# as we only need one output # as we only need one output

View File

@ -69,17 +69,20 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=20, default=20,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -110,8 +113,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
return parser return parser
@ -243,9 +245,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -103,9 +103,7 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64) y_padded = y_padded.to(torch.int64)
boundary = torch.zeros( boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens boundary[:, 2] = y_lens
boundary[:, 3] = x_lens boundary[:, 3] = x_lens

View File

@ -73,9 +73,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -100,10 +102,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -117,8 +121,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -211,10 +214,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -273,9 +275,7 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
@ -319,9 +319,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -126,8 +126,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -389,9 +388,7 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -504,9 +501,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -625,9 +620,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu": elif activation == "gelu":
return nn.functional.gelu return nn.functional.gelu
raise RuntimeError( raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):

View File

@ -29,10 +29,7 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
OnTheFlyFeatures,
PrecomputedFeatures,
)
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.utils import str2bool from icefall.utils import str2bool
@ -46,59 +43,69 @@ class AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc."
),
) )
group.add_argument( group.add_argument(
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help=(
"single batch. You can reduce it if it causes CUDA OOM.", "Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, the batches will come from buckets of " help=(
"similar duration (saves padding frames).", "When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=30, default=30,
help="The number of buckets for the DynamicBucketingSampler " help=(
"(you might want to increase it for larger datasets).", "The number of buckets for the DynamicBucketingSampler "
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled (=default), the examples will be " help=(
"shuffled for each epoch.", "When enabled (=default), the examples will be shuffled for each epoch."
),
) )
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, each batch will have the " help=(
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.", "were used to construct it."
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that collect the batches.",
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -112,18 +119,22 @@ class AsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help="Used only when --enable-spec-aug is True. " help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp.", "A value less than 1 means to disable time warp."
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, select noise from MUSAN and mix it" help=(
"with training dataset. ", "When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
) )
group.add_argument( group.add_argument(
@ -137,9 +148,11 @@ class AsrDataModule:
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, use on-the-fly cut mixing and feature " help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available. Used only in dev/test CutSet", "if available. Used only in dev/test CutSet"
),
) )
def train_dataloaders( def train_dataloaders(
@ -162,9 +175,7 @@ class AsrDataModule:
if cuts_musan is not None: if cuts_musan is not None:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
@ -173,9 +184,7 @@ class AsrDataModule:
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info( logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -252,9 +261,7 @@ class AsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -93,16 +93,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=30, default=30,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -170,8 +173,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -227,9 +229,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -241,10 +241,7 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -288,11 +285,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -365,9 +358,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -393,9 +384,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -406,8 +395,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -448,9 +436,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -68,17 +68,20 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=20, default=20,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -109,8 +112,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
return parser return parser
@ -241,9 +243,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -87,9 +87,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -115,10 +117,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -165,15 +169,16 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help="Maximum number of symbols per frame. " help=(
"Use only when --method is greedy_search", "Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
) )
return parser return parser
@ -194,10 +199,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -254,13 +258,9 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -308,9 +308,7 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.method}")
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -327,9 +325,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -149,8 +149,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -168,8 +167,7 @@ def get_parser():
"--datatang-prob", "--datatang-prob",
type=float, type=float,
default=0.2, default=0.2,
help="The probability to select a batch from the " help="The probability to select a batch from the aidatatang_200zh dataset",
"aidatatang_200zh dataset",
) )
return parser return parser
@ -449,9 +447,7 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -605,9 +601,7 @@ def train_one_epoch(
f"train/current_{prefix}_", f"train/current_{prefix}_",
params.batch_idx_train, params.batch_idx_train,
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
aishell_tot_loss.write_summary( aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train tb_writer, "train/aishell_tot_", params.batch_idx_train
) )
@ -735,9 +729,7 @@ def run(rank, world_size, args):
train_datatang_cuts = train_datatang_cuts.repeat(times=None) train_datatang_cuts = train_datatang_cuts.repeat(times=None)
if args.enable_musan: if args.enable_musan:
cuts_musan = load_manifest( cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
)
else: else:
cuts_musan = None cuts_musan = None
@ -776,9 +768,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

View File

@ -94,16 +94,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=30, default=30,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -171,8 +174,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -231,9 +233,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -245,10 +245,7 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -292,11 +289,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -369,9 +362,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -397,9 +388,7 @@ def save_results(
# we compute CER for aishell dataset. # we compute CER for aishell dataset.
results_char = [] results_char = []
for res in results: for res in results:
results_char.append( results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -410,8 +399,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
@ -452,9 +440,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

View File

@ -68,17 +68,20 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=20, default=20,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=10,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -109,8 +112,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
return parser return parser
@ -241,9 +243,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -87,9 +87,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -115,10 +117,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -165,15 +169,16 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=1, default=1,
help="Maximum number of symbols per frame. " help=(
"Use only when --method is greedy_search", "Maximum number of symbols per frame. "
"Use only when --method is greedy_search"
),
) )
return parser return parser
@ -194,10 +199,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -254,13 +258,9 @@ def main():
feature_lens = [f.size(0) for f in features] feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyp_list = [] hyp_list = []
@ -308,9 +308,7 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.method}")
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp) hyp_list.append(hyp)
hyps = [] hyps = []
@ -327,9 +325,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -142,8 +142,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -414,9 +413,7 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -529,9 +526,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -657,9 +652,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0: if rank == 0:

0
egs/aishell2/ASR/local/__init__.py Executable file → Normal file
View File

View File

@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -111,9 +109,7 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

View File

@ -76,10 +76,12 @@ class AiShell2AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc."
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -91,59 +93,74 @@ class AiShell2AsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help=(
"single batch. You can reduce it if it causes CUDA OOM.", "Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, the batches will come from buckets of " help=(
"similar duration (saves padding frames).", "When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=30, default=30,
help="The number of buckets for the DynamicBucketingSampler" help=(
"(you might want to increase it for larger datasets).", "The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, utterances (cuts) will be concatenated " help=(
"to minimize the amount of padding.", "When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help="Determines the maximum duration of a concatenated cut " help=(
"relative to the duration of the longest cut in a batch.", "Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help="The amount of padding (in seconds) inserted between " help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.", "noise augmentation is used."
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, use on-the-fly cut mixing and feature " help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available.", "if available."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled (=default), the examples will be " help=(
"shuffled for each epoch.", "When enabled (=default), the examples will be shuffled for each epoch."
),
) )
group.add_argument( group.add_argument(
"--drop-last", "--drop-last",
@ -155,17 +172,18 @@ class AiShell2AsrDataModule:
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, each batch will have the " help=(
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.", "were used to construct it."
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that collect the batches.",
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -179,18 +197,22 @@ class AiShell2AsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help="Used only when --enable-spec-aug is True. " help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp.", "A value less than 1 means to disable time warp."
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, select noise from MUSAN and mix it" help=(
"with training dataset. ", "When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
) )
group.add_argument( group.add_argument(
@ -216,20 +238,16 @@ class AiShell2AsrDataModule:
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest( cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append( transforms.append(
CutMix( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " "Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -244,9 +262,7 @@ class AiShell2AsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info( logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -290,9 +306,7 @@ class AiShell2AsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -348,9 +362,7 @@ class AiShell2AsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:
@ -406,9 +418,7 @@ class AiShell2AsrDataModule:
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
return load_manifest_lazy( return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_cuts(self) -> CutSet: def test_cuts(self) -> CutSet:

View File

@ -168,20 +168,24 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=True, default=True,
help="Whether to load averaged model. Currently it only supports " help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. "
),
) )
parser.add_argument( parser.add_argument(
@ -269,8 +273,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -348,9 +351,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -409,10 +410,7 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -538,9 +536,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -573,8 +569,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -625,9 +620,7 @@ def main():
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -661,13 +654,12 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -690,13 +682,12 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -724,7 +715,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
f"Calculating the averaged model over epoch range from " "Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -749,9 +740,7 @@ def main():
) )
decoding_graph.scores *= params.ngram_lm_scale decoding_graph.scores *= params.ngram_lm_scale
else: else:
decoding_graph = k2.trivial_graph( decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None

View File

@ -89,20 +89,24 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help="Whether to load averaged model. Currently it only supports " help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. "
),
) )
parser.add_argument( parser.add_argument(
@ -133,8 +137,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -167,13 +170,12 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -196,13 +198,12 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -230,7 +231,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
f"Calculating the averaged model over epoch range from " "Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -266,9 +267,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -81,9 +81,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -109,10 +111,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -159,8 +163,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -191,10 +194,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -254,15 +256,11 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []
@ -334,9 +332,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -92,9 +92,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -220,8 +218,7 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need " help="The initial learning rate. This value should not need to be changed.",
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -244,42 +241,45 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help="The prune range for rnnt loss, it means how many symbols(context)" help=(
"we are using to compute the loss", "The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help="The scale to smooth the loss with lm " help=(
"(output of prediction network) part.", "The scale to smooth the loss with lm (output of prediction network) part."
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)part.",
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help="To get pruning ranges, we will calculate a simple version" help=(
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss."
),
) )
parser.add_argument( parser.add_argument(
@ -603,11 +603,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -636,23 +632,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -771,9 +760,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch( display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
batch, params=params, graph_compiler=graph_compiler
)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
@ -829,9 +816,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -1104,9 +1089,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch( display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
batch, params=params, graph_compiler=graph_compiler
)
raise raise

View File

@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cut_set = cut_set.compute_and_store_features( cut_set = cut_set.compute_and_store_features(
extractor=extractor, extractor=extractor,
@ -120,9 +118,7 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon( def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst( fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -50,15 +50,15 @@ def get_parser():
"-n", "-n",
default=1, default=1,
type=int, type=int,
help="number of characters to split, i.e., \ help=(
aabb -> a a b b with -n 1 and aa bb with -n 2", "number of characters to split, i.e., aabb -> a a b"
" b with -n 1 and aa bb with -n 2"
),
) )
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument( parser.add_argument("--space", default="<space>", type=str, help="space symbol")
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -66,9 +66,7 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument( parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@ -108,8 +106,7 @@ def token2id(
if token_type == "lazy_pinyin": if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list) text = lazy_pinyin(chars_list)
sub_ids = [ sub_ids = [
token_table[txt] if txt in token_table else oov_id token_table[txt] if txt in token_table else oov_id for txt in text
for txt in text
] ]
ids.append(sub_ids) ids.append(sub_ids)
else: # token_type = "pinyin" else: # token_type = "pinyin"
@ -135,9 +132,7 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")( f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -74,10 +74,12 @@ class Aishell4AsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc."
),
) )
group.add_argument( group.add_argument(
@ -91,66 +93,81 @@ class Aishell4AsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help=(
"single batch. You can reduce it if it causes CUDA OOM.", "Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, the batches will come from buckets of " help=(
"similar duration (saves padding frames).", "When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=300,
help="The number of buckets for the DynamicBucketingSampler" help=(
"(you might want to increase it for larger datasets).", "The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, utterances (cuts) will be concatenated " help=(
"to minimize the amount of padding.", "When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help="Determines the maximum duration of a concatenated cut " help=(
"relative to the duration of the longest cut in a batch.", "Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help="The amount of padding (in seconds) inserted between " help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.", "noise augmentation is used."
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, use on-the-fly cut mixing and feature " help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available.", "if available."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled (=default), the examples will be " help=(
"shuffled for each epoch.", "When enabled (=default), the examples will be shuffled for each epoch."
),
) )
group.add_argument( group.add_argument(
@ -164,17 +181,18 @@ class Aishell4AsrDataModule:
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, each batch will have the " help=(
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.", "were used to construct it."
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that collect the batches.",
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -188,18 +206,22 @@ class Aishell4AsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help="Used only when --enable-spec-aug is True. " help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp.", "A value less than 1 means to disable time warp."
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, select noise from MUSAN and mix it" help=(
"with training dataset. ", "When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
) )
group.add_argument( group.add_argument(
@ -222,24 +244,20 @@ class Aishell4AsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest( cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " "Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -254,9 +272,7 @@ class Aishell4AsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info( logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -300,9 +316,7 @@ class Aishell4AsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -359,9 +373,7 @@ class Aishell4AsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -117,20 +117,24 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help="Whether to load averaged model. Currently it only supports " help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. "
),
) )
parser.add_argument( parser.add_argument(
@ -201,8 +205,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -260,9 +263,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -277,10 +278,7 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -326,11 +324,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -401,9 +395,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -436,8 +428,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -480,9 +471,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -510,13 +499,12 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -543,13 +531,12 @@ def main():
) )
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -578,7 +565,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
f"Calculating the averaged model over epoch range from " "Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)

View File

@ -89,20 +89,24 @@ def get_parser():
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch' and '--iter'"
),
) )
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=False,
help="Whether to load averaged model. Currently it only supports " help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and " "Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ", "`epoch` are loaded for averaging. "
),
) )
parser.add_argument( parser.add_argument(
@ -136,8 +140,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -169,13 +172,12 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg: elif len(filenames) < params.avg:
raise ValueError( raise ValueError(
@ -202,13 +204,12 @@ def main():
) )
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
f" --iter {params.iter}, --avg {params.avg}"
) )
elif len(filenames) < params.avg + 1: elif len(filenames) < params.avg + 1:
raise ValueError( raise ValueError(
@ -237,7 +238,7 @@ def main():
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info( logging.info(
f"Calculating the averaged model over epoch range from " "Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}" f"{start} (excluded) to {params.epoch}"
) )
model.to(device) model.to(device)
@ -276,9 +277,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -94,9 +94,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -122,10 +124,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -172,8 +176,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -204,10 +207,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -266,15 +268,11 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []
@ -306,10 +304,7 @@ def main():
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -350,9 +345,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -85,9 +85,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
@ -213,8 +211,7 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need " help="The initial learning rate. This value should not need to be changed.",
"to be changed.",
) )
parser.add_argument( parser.add_argument(
@ -237,42 +234,45 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help="The prune range for rnnt loss, it means how many symbols(context)" help=(
"we are using to compute the loss", "The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help="The scale to smooth the loss with lm " help=(
"(output of prediction network) part.", "The scale to smooth the loss with lm (output of prediction network) part."
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)part.",
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help="To get pruning ranges, we will calculate a simple version" help=(
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss."
),
) )
parser.add_argument( parser.add_argument(
@ -599,11 +599,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -633,22 +629,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -827,9 +816,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
) )
if "train" in partition: if "train" in partition:
cut_set = ( cut_set = (
cut_set cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
) )
cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = num_jobs if ex is None else 80
cur_num_jobs = min(cur_num_jobs, len(cut_set)) cur_num_jobs = min(cur_num_jobs, len(cut_set))
@ -121,9 +119,7 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state cur_state = loop_state
word = word2id[word] word = word2id[word]
pieces = [ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1): for i in range(len(pieces) - 1):
w = word if i == 0 else eps w = word if i == 0 else eps
@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False return False
def generate_lexicon( def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table. """Generate a lexicon from a word list and token_sym_table.
Args: Args:

View File

@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args() return parser.parse_args()

View File

@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L") fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst( fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig") fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

View File

@ -30,8 +30,8 @@ with word segmenting:
import argparse import argparse
import paddle
import jieba import jieba
import paddle
from tqdm import tqdm from tqdm import tqdm
paddle.enable_static() paddle.enable_static()

View File

@ -50,15 +50,15 @@ def get_parser():
"-n", "-n",
default=1, default=1,
type=int, type=int,
help="number of characters to split, i.e., \ help=(
aabb -> a a b b with -n 1 and aa bb with -n 2", "number of characters to split, i.e., aabb -> a a b"
" b with -n 1 and aa bb with -n 2"
),
) )
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument( parser.add_argument("--space", default="<space>", type=str, help="space symbol")
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -66,9 +66,7 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument( parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@ -108,8 +106,7 @@ def token2id(
if token_type == "lazy_pinyin": if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list) text = lazy_pinyin(chars_list)
sub_ids = [ sub_ids = [
token_table[txt] if txt in token_table else oov_id token_table[txt] if txt in token_table else oov_id for txt in text
for txt in text
] ]
ids.append(sub_ids) ids.append(sub_ids)
else: # token_type = "pinyin" else: # token_type = "pinyin"
@ -135,9 +132,7 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")( f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -81,10 +81,12 @@ class AlimeetingAsrDataModule:
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="ASR data related options",
description="These options are used for the preparation of " description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc."
),
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
@ -96,75 +98,91 @@ class AlimeetingAsrDataModule:
"--max-duration", "--max-duration",
type=int, type=int,
default=200.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help=(
"single batch. You can reduce it if it causes CUDA OOM.", "Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, the batches will come from buckets of " help=(
"similar duration (saves padding frames).", "When enabled, the batches will come from buckets of "
"similar duration (saves padding frames)."
),
) )
group.add_argument( group.add_argument(
"--num-buckets", "--num-buckets",
type=int, type=int,
default=300, default=300,
help="The number of buckets for the DynamicBucketingSampler" help=(
"(you might want to increase it for larger datasets).", "The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets)."
),
) )
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, utterances (cuts) will be concatenated " help=(
"to minimize the amount of padding.", "When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
) )
group.add_argument( group.add_argument(
"--duration-factor", "--duration-factor",
type=float, type=float,
default=1.0, default=1.0,
help="Determines the maximum duration of a concatenated cut " help=(
"relative to the duration of the longest cut in a batch.", "Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
) )
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
default=1.0, default=1.0,
help="The amount of padding (in seconds) inserted between " help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when " "concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.", "noise augmentation is used."
),
) )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
default=False, default=False,
help="When enabled, use on-the-fly cut mixing and feature " help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests " "extraction. Will drop existing precomputed feature manifests "
"if available.", "if available."
),
) )
group.add_argument( group.add_argument(
"--shuffle", "--shuffle",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled (=default), the examples will be " help=(
"shuffled for each epoch.", "When enabled (=default), the examples will be shuffled for each epoch."
),
) )
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, each batch will have the " help=(
"When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that " "field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.", "were used to construct it."
),
) )
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=2, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that collect the batches.",
"collect the batches.",
) )
group.add_argument( group.add_argument(
@ -178,18 +196,22 @@ class AlimeetingAsrDataModule:
"--spec-aug-time-warp-factor", "--spec-aug-time-warp-factor",
type=int, type=int,
default=80, default=80,
help="Used only when --enable-spec-aug is True. " help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. " "It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. " "Larger values mean more warping. "
"A value less than 1 means to disable time warp.", "A value less than 1 means to disable time warp."
),
) )
group.add_argument( group.add_argument(
"--enable-musan", "--enable-musan",
type=str2bool, type=str2bool,
default=True, default=True,
help="When enabled, select noise from MUSAN and mix it" help=(
"with training dataset. ", "When enabled, select noise from MUSAN and mix it"
"with training dataset. "
),
) )
def train_dataloaders( def train_dataloaders(
@ -205,24 +227,20 @@ class AlimeetingAsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest( cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
logging.info("Enable MUSAN") logging.info("Enable MUSAN")
transforms.append( transforms.append(
CutMix( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
logging.info( logging.info(
f"Using cut concatenation with duration factor " "Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}." f"{self.args.duration_factor} and gap {self.args.gap}."
) )
# Cut concatenation should be the first transform in the list, # Cut concatenation should be the first transform in the list,
@ -237,9 +255,7 @@ class AlimeetingAsrDataModule:
input_transforms = [] input_transforms = []
if self.args.enable_spec_aug: if self.args.enable_spec_aug:
logging.info("Enable SpecAugment") logging.info("Enable SpecAugment")
logging.info( logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version. # Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is # In different Lhotse's versions, the default of num_frame_masks is
# different. # different.
@ -282,9 +298,7 @@ class AlimeetingAsrDataModule:
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -341,9 +355,7 @@ class AlimeetingAsrDataModule:
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:

View File

@ -70,11 +70,7 @@ from beam_search import (
from lhotse.cut import Cut from lhotse.cut import Cut
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -93,25 +89,30 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--batch", "--batch",
type=int, type=int,
default=None, default=None,
help="It specifies the batch checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -193,8 +194,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -249,9 +249,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -266,10 +264,7 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -315,11 +310,7 @@ def decode_one_batch(
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -390,9 +381,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -425,8 +414,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -563,8 +551,7 @@ def main():
) )
dev_shards = [ dev_shards = [
str(path) str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
] ]
cuts_dev_webdataset = CutSet.from_webdataset( cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards, dev_shards,
@ -574,8 +561,7 @@ def main():
) )
test_shards = [ test_shards = [
str(path) str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
] ]
cuts_test_webdataset = CutSet.from_webdataset( cuts_test_webdataset = CutSet.from_webdataset(
test_shards, test_shards,
@ -588,9 +574,7 @@ def main():
return 1.0 <= c.duration return 1.0 <= c.duration
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
cuts_test_webdataset = cuts_test_webdataset.filter( cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
remove_short_and_long_utt
)
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)

View File

@ -62,17 +62,20 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help=(
"Note: Epoch counts from 0.", "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
),
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. "
),
) )
parser.add_argument( parser.add_argument(
@ -103,8 +106,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
return parser return parser
@ -173,9 +175,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -85,9 +85,11 @@ def get_parser():
"--checkpoint", "--checkpoint",
type=str, type=str,
required=True, required=True,
help="Path to the checkpoint. " help=(
"Path to the checkpoint. "
"The checkpoint is assumed to be saved by " "The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().", "icefall.checkpoint.save_checkpoint()."
),
) )
parser.add_argument( parser.add_argument(
@ -112,10 +114,12 @@ def get_parser():
"sound_files", "sound_files",
type=str, type=str,
nargs="+", nargs="+",
help="The input sound file(s) to transcribe. " help=(
"The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. "
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz."
),
) )
parser.add_argument( parser.add_argument(
@ -162,8 +166,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -193,10 +196,9 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, ( assert (
f"expected sample rate: {expected_sample_rate}. " sample_rate == expected_sample_rate
f"Given: {sample_rate}" ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
)
# We use only the first channel # We use only the first channel
ans.append(wave[0]) ans.append(wave[0])
return ans return ans
@ -257,9 +259,7 @@ def main():
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
features = pad_sequence( features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
@ -284,10 +284,7 @@ def main():
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -339,9 +336,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -81,9 +81,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@ -187,42 +185,45 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
default=5, default=5,
help="The prune range for rnnt loss, it means how many symbols(context)" help=(
"we are using to compute the loss", "The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss"
),
) )
parser.add_argument( parser.add_argument(
"--lm-scale", "--lm-scale",
type=float, type=float,
default=0.25, default=0.25,
help="The scale to smooth the loss with lm " help=(
"(output of prediction network) part.", "The scale to smooth the loss with lm (output of prediction network) part."
),
) )
parser.add_argument( parser.add_argument(
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)part.",
"part.",
) )
parser.add_argument( parser.add_argument(
"--simple-loss-scale", "--simple-loss-scale",
type=float, type=float,
default=0.5, default=0.5,
help="To get pruning ranges, we will calculate a simple version" help=(
"To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for" "loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss" "training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss."
),
) )
parser.add_argument( parser.add_argument(
@ -542,22 +543,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -711,9 +705,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")

View File

@ -25,15 +25,10 @@ from random import Random
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from lhotse import ( from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
CutSet, CutSet,
Fbank, Fbank,
FbankConfig, FbankConfig,
# fmt: off
# See the following for why LilcomChunkyWriter is preferred
# https://github.com/k2-fsa/icefall/pull/404
# https://github.com/lhotse-speech/lhotse/pull/527
# fmt: on
LilcomChunkyWriter, LilcomChunkyWriter,
RecordingSet, RecordingSet,
SupervisionSet, SupervisionSet,
@ -81,17 +76,13 @@ def make_cutset_blueprints(
cut_sets.append((f"eval{i}", cut_set)) cut_sets.append((f"eval{i}", cut_set))
# Create train and valid cuts # Create train and valid cuts
logging.info( logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
"Loading, trimming, and shuffling the remaining core+noncore cuts."
)
recording_set = RecordingSet.from_file( recording_set = RecordingSet.from_file(
manifest_dir / "csj_recordings_core.jsonl.gz" manifest_dir / "csj_recordings_core.jsonl.gz"
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
supervision_set = SupervisionSet.from_file( supervision_set = SupervisionSet.from_file(
manifest_dir / "csj_supervisions_core.jsonl.gz" manifest_dir / "csj_supervisions_core.jsonl.gz"
) + SupervisionSet.from_file( ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
manifest_dir / "csj_supervisions_noncore.jsonl.gz"
)
cut_set = CutSet.from_manifests( cut_set = CutSet.from_manifests(
recordings=recording_set, recordings=recording_set,
@ -101,15 +92,12 @@ def make_cutset_blueprints(
cut_set = cut_set.shuffle(Random(RNG_SEED)) cut_set = cut_set.shuffle(Random(RNG_SEED))
logging.info( logging.info(
"Creating valid and train cuts from core and noncore," f"Creating valid and train cuts from core and noncore,split at {split}."
f"split at {split}."
) )
valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
train_set = CutSet.from_cuts(islice(cut_set, split, None)) train_set = CutSet.from_cuts(islice(cut_set, split, None))
train_set = ( train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
)
cut_sets.extend([("valid", valid_set), ("train", train_set)]) cut_sets.extend([("valid", valid_set), ("train", train_set)])
@ -122,15 +110,9 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
"--manifest-dir", type=Path, help="Path to save manifests" parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
) parser.add_argument("--split", type=int, default=4000, help="Split at this index")
parser.add_argument(
"--fbank-dir", type=Path, help="Path to save fbank features"
)
parser.add_argument(
"--split", type=int, default=4000, help="Split at this index"
)
return parser.parse_args() return parser.parse_args()
@ -141,9 +123,7 @@ def main():
extractor = Fbank(FbankConfig(num_mel_bins=80)) extractor = Fbank(FbankConfig(num_mel_bins=80))
num_jobs = min(16, os.cpu_count()) num_jobs = min(16, os.cpu_count())
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)

View File

@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
ARGPARSE_DESCRIPTION = """ ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset. This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests. It looks for manifests in the directory data/manifests.
@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
# create chunks of Musan with duration 5 - 10 seconds # create chunks of Musan with duration 5 - 10 seconds
musan_cuts = ( musan_cuts = (
CutSet.from_manifests( CutSet.from_manifests(
recordings=combine( recordings=combine(part["recordings"] for part in manifests.values())
part["recordings"] for part in manifests.values()
)
) )
.cut_into_windows(10.0) .cut_into_windows(10.0)
.filter(lambda c: c.duration > 5) .filter(lambda c: c.duration > 5)
@ -107,21 +104,15 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
"--manifest-dir", type=Path, help="Path to save manifests" parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
)
parser.add_argument(
"--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan(args.manifest_dir, args.fbank_dir) compute_fbank_musan(args.manifest_dir, args.fbank_dir)

View File

@ -318,4 +318,3 @@ spk_id = 2
= ǐa = ǐa
= ǐu = ǐu
= ǐo = ǐo

View File

@ -318,4 +318,3 @@ spk_id = 2
= ǐa = ǐa
= ǐu = ǐu
= ǐo = ǐo

View File

@ -318,4 +318,3 @@ spk_id = 2
= ǐa = ǐa
= ǐu = ǐu
= ǐo = ǐo

Some files were not shown because too many files have changed in this diff Show More