Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Revert "Apply new Black style changes"
This commit is contained in: parent a7fbb18bdc, commit 60317120ca
@@ -1,2 +0,0 @@
-# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
-d110b04ad389134c82fa314e3aafc7b40043efb0
.github/workflows/style_check.yml (vendored): 11 changed lines
@@ -45,18 +45,17 @@ jobs:

       - name: Install Python dependencies
        run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
-          # Click issue fixed in https://github.com/psf/black/pull/2966
+          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+          # See https://github.com/psf/black/issues/2964
+          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4

       - name: Run flake8
         shell: bash
         working-directory: ${{github.workspace}}
         run: |
           # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
-            --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
+          flake8 . --count --show-source --statistics
+          flake8 .

       - name: Run black
         shell: bash
.pre-commit-config.yaml
@@ -1,38 +1,26 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 21.6b0
     hooks:
       - id: black
-        args: ["--line-length=88"]
-        additional_dependencies: ['click==8.1.0']
+        args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']
         exclude: icefall\/__init__\.py

   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 3.9.2
     hooks:
       - id: flake8
-        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
-
-        # What are we ignoring here?
-        # E203: whitespace before ':'
-        # E266: too many leading '#' for block comment
-        # E501: line too long
-        # F401: module imported but unused
-        # E402: module level import not at top of file
-        # F403: 'from module import *' used; unable to detect undefined names
-        # F841: local variable is assigned to but never used
-        # W503: line break before binary operator
-
-        # In addition, the default ignore list is:
-        # E121,E123,E126,E226,E24,E704,W503,W504
+        args: [--max-line-length=80]

   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.9.2
     hooks:
       - id: isort
-        args: ["--profile=black"]
+        args: [--profile=black, --line-length=80]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.2.0
+    rev: v4.0.1
     hooks:
       - id: check-executables-have-shebangs
       - id: end-of-file-fixer
@@ -88,3 +88,4 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

 WORKDIR /workspace/icefall
+
@@ -19,3 +19,4 @@ It can be downloaded from `<https://www.openslr.org/33/>`_
    tdnn_lstm_ctc
    conformer_ctc
    stateless_transducer
+
@@ -6,3 +6,4 @@ TIMIT

    tdnn_ligru_ctc
    tdnn_lstm_ctc
+
@@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                cut_set
+                + cut_set.perturb_speed(0.9)
+                + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,
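Both sides of this hunk build the same triple-sized, speed-perturbed training set; only the line wrapping differs. A minimal sketch of the underlying lhotse idiom (the manifest path is illustrative):

    from lhotse import CutSet

    # Path is illustrative; any cut manifest works.
    cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")

    # Concatenating CutSets with "+" and resampling each cut at 0.9x and
    # 1.1x speed triples the amount of training data.
    cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)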
@@ -114,7 +116,9 @@ def get_args():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)

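This formatter/basicConfig pair is reflowed the same way in several files below; both spellings configure logging identically. A self-contained sketch of what it produces:

    import logging

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    # Prints something like:
    # 2022-07-27 12:00:00,000 INFO [demo.py:8] starting
    logging.info("starting")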
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
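The reflowed comprehension maps tokens to IDs with an <unk> fallback for out-of-vocabulary entries. A toy sketch (the symbol table contents are made up):

    token2id = {"<unk>": 0, "a": 1, "b": 2}
    pieces = ["a", "b", "z"]  # "z" is out of vocabulary

    ids = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
    assert ids == [1, 2, 0]

    # dict.get with a default is an equivalent, shorter spelling:
    assert ids == [token2id.get(i, token2id["<unk>"]) for i in pieces]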
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False


-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.

     Args:
@@ -317,7 +317,9 @@ def lexicon_to_fst(

 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()

@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")

-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e., aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="<space>", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
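Most hunks in this commit are the same mechanical change: Black 22.x wrapped long argparse help strings in a parenthesized group of adjacent literals, and the revert restores one literal per line (or a backslash continuation). Adjacent string literals are joined at compile time, so every variant denotes the same string; a minimal sketch:

    # All three spellings produce the identical string:
    a = (
        "adjacent string literals "
        "are joined at compile time"
    )
    b = "adjacent string literals " "are joined at compile time"
    c = "adjacent string literals are joined at compile time"
    assert a == b == c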
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g., <NOISE> etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
         if token_type == "lazy_pinyin":
             text = lazy_pinyin(chars_list)
             sub_ids = [
-                token_table[txt] if txt in token_table else oov_id for txt in text
+                token_table[txt] if txt in token_table else oov_id
+                for txt in text
             ]
             ids.append(sub_ids)
         else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )

     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
@@ -113,3 +113,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     ./local/prepare_char.py
   fi
 fi
+
@@ -81,12 +81,10 @@ class Aidatatang_200zhAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
+            description="These options are used for the preparation of "
             "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
             "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class Aidatatang_200zhAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
        )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
+            help="The amount of padding (in seconds) inserted between "
             "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
+            help="When enabled, use on-the-fly cut mixing and feature "
             "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
+            help="When enabled, each batch will have the "
             "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            "were used to construct it.",
         )

         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )

         group.add_argument(
@@ -196,22 +178,18 @@ class Aidatatang_200zhAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
+            help="Used only when --enable-spec-aug is True. "
             "It specifies the factor for time warping in SpecAugment. "
             "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            "A value less than 1 means to disable time warp.",
         )

         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )

     def train_dataloaders(
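All of these options hang off a single parser group, so they appear under a common heading in --help. A stripped-down sketch of the pattern (option names echo the hunk above; the str2bool here is a minimal stand-in for icefall.utils.str2bool):

    import argparse


    def str2bool(v: str) -> bool:
        # Minimal stand-in for icefall.utils.str2bool.
        return v.lower() in ("yes", "true", "t", "1")


    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(
        title="ASR data related options",
        description="Options controlling DataLoader preparation.",
    )
    group.add_argument("--max-duration", type=int, default=200)
    group.add_argument("--bucketing-sampler", type=str2bool, default=True)

    args = parser.parse_args(
        ["--max-duration", "300", "--bucketing-sampler", "false"]
    )
    assert args.max_duration == 300 and args.bucketing_sampler is False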
@@ -227,20 +205,24 @@ class Aidatatang_200zhAsrDataModule:
           The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")

         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -255,7 +237,9 @@ class Aidatatang_200zhAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,7 +282,9 @@ class Aidatatang_200zhAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -354,7 +340,9 @@ class Aidatatang_200zhAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
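The dataloader hunks above all reflow one construction: an optional MUSAN CutMix transform plus on-the-fly fbank features feeding K2SpeechRecognitionDataset. A condensed sketch of that assembly, assuming the lhotse release the recipe pins (CutMix's keyword names changed in later lhotse versions) and an existing MUSAN cut manifest at an illustrative path:

    from lhotse import Fbank, FbankConfig, load_manifest
    from lhotse.dataset import (
        CutMix,
        K2SpeechRecognitionDataset,
        OnTheFlyFeatures,
    )

    cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

    transforms = [
        # Mix MUSAN noise into half of the cuts at 10-20 dB SNR.
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
    ]

    train = K2SpeechRecognitionDataset(
        cut_transforms=transforms,
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        return_cuts=True,
    )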
@@ -69,7 +69,11 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -88,30 +92,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )

     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )

     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        "'--epoch'. ",
     )

     parser.add_argument(
@@ -193,7 +192,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,7 +249,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []

     if params.decoding_method == "fast_beam_search":
@@ -264,7 +266,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -310,7 +315,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -381,7 +390,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results

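The fast_beam_search hunk only re-wraps the dictionary key; adjacent f-string literals concatenate exactly like plain ones, so both spellings produce an identical key. A quick sketch with stand-in values:

    beam, max_contexts, max_states = 4, 4, 8  # stand-in values

    key = (
        f"beam_{beam}_"
        f"max_contexts_{max_contexts}_"
        f"max_states_{max_states}"
    )
    assert key == f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
    assert key == "beam_4_max_contexts_4_max_states_8"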
@@ -414,7 +425,8 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
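The expression can be split before the slash because params.res_dir is a pathlib.Path and "/" is its join operator. A small sketch (names are stand-ins):

    from pathlib import Path

    res_dir = Path("exp/results")  # stand-in for params.res_dir
    errs_info = (
        res_dir
        / "wer-summary-test-greedy.txt"  # stand-in file name
    )
    assert errs_info == Path("exp/results/wer-summary-test-greedy.txt")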
@@ -62,20 +62,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )

     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        "'--epoch'. ",
     )

     parser.add_argument(
@@ -106,7 +103,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     return parser
@@ -175,7 +173,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
@@ -85,11 +85,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
+        help="Path to the checkpoint. "
         "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        "icefall.checkpoint.save_checkpoint().",
     )

     parser.add_argument(
@@ -114,12 +112,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those supported by torchaudio.load(). "
         "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        "The sample rate has to be 16kHz.",
     )

     parser.add_argument(
@@ -166,7 +162,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     parser.add_argument(
@@ -196,9 +193,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,7 +257,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]

-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )

     feature_lengths = torch.tensor(feature_lengths, device=device)

@@ -284,7 +284,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -336,7 +339,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
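read_sound_files and the padding call reflowed above follow a common recipe: load each file, keep channel 0, then pad the batch with log(1e-10), a very small log-energy, so padded frames look like silence to a log-mel front end. A sketch under those assumptions:

    import math

    import torch
    import torchaudio
    from torch.nn.utils.rnn import pad_sequence


    def read_sound_files(filenames, expected_sample_rate=16000):
        ans = []
        for f in filenames:
            wave, sample_rate = torchaudio.load(f)
            assert sample_rate == expected_sample_rate, (
                f"expected sample rate: {expected_sample_rate}. "
                f"Given: {sample_rate}"
            )
            ans.append(wave[0])  # keep only the first channel
        return ans


    # Given variable-length feature tensors, pad them to a common length:
    features = [torch.randn(100, 80), torch.randn(80, 80)]
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )
    assert features.shape == (2, 100, 80)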
@@ -81,7 +81,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]

 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

@@ -185,45 +187,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )

     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )

     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )

     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
+        help="To get pruning ranges, we will calculate a simple version"
         "loss(joiner is just addition), this simple loss also uses for"
         "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        "with this parameter before adding to the final loss.",
     )

     parser.add_argument(
@@ -543,15 +542,22 @@ def compute_loss(
             # overwhelming the simple_loss and causing it to diverge,
             # in case it had not fully learned the alignment yet.
             pruned_loss_scale = (
-                0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+                0.0
+                if warmup < 1.0
+                else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            )
+            loss = (
+                params.simple_loss_scale * simple_loss
+                + pruned_loss_scale * pruned_loss
             )
-            loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
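The warmup gating being reflowed here weights the pruned RNN-T loss at 0 early in training, then 0.1, then 1.0. A standalone sketch of the schedule (the function name is mine):

    def pruned_loss_scale(warmup: float) -> float:
        # 0.0 while warming up, 0.1 in a transition band, then 1.0.
        # (As written, warmup == 1.0 exactly falls through to 1.0.)
        return 0.0 if warmup < 1.0 else (
            0.1 if warmup > 1.0 and warmup < 2.0 else 1.0
        )

    assert pruned_loss_scale(0.5) == 0.0
    assert pruned_loss_scale(1.5) == 0.1
    assert pruned_loss_scale(2.5) == 1.0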
@@ -705,7 +711,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         self.ff_scale = 0.5

         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

@@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

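The macaron feed-forward branch being re-wrapped here is a half-weighted pre-norm residual: normalize, run the FFN, drop out, scale by 0.5, add back. A toy sketch (a plain ReLU FFN stands in for the recipe's module):

    import torch
    import torch.nn as nn

    d_model, ff_dim = 16, 64
    feed_forward_macaron = nn.Sequential(
        nn.Linear(d_model, ff_dim), nn.ReLU(), nn.Linear(ff_dim, d_model)
    )
    norm_ff_macaron = nn.LayerNorm(d_model)
    dropout = nn.Dropout(0.1)
    ff_scale = 0.5  # macaron-style FFNs contribute at half weight

    src = torch.randn(10, 2, d_model)  # (time, batch, feature)
    residual = src
    src = norm_ff_macaron(src)  # pre-norm branch (normalize_before=True)
    src = residual + ff_scale * dropout(feed_forward_macaron(src))
    assert src.shape == (10, 2, d_model)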
@@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):

     """

-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(1) >= x.size(1) * 2 - 1:
             # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                x.device
+            ):
                 self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.

         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)

@@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

         if need_weights:
             # average attention weights over heads
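The reshape being reflowed above is the standard multi-head recombination step: merge the per-head outputs back into the model dimension. A toy-shaped sketch (the dimensions are illustrative):

    import torch

    bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8
    embed_dim = num_heads * head_dim

    # (batch * heads, tgt_len, head_dim), as produced by torch.bmm.
    attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim)

    out = (
        attn_output.transpose(0, 1)     # (tgt_len, batch * heads, head_dim)
        .contiguous()                   # make the layout compatible with view
        .view(tgt_len, bsz, embed_dim)  # concatenate heads per position
    )
    assert out.shape == (tgt_len, bsz, embed_dim)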
@@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):

     """

-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
@@ -58,19 +58,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        "'--epoch'. ",
     )

     parser.add_argument(
@@ -404,7 +401,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results

@@ -432,7 +431,9 @@ def save_results(
     # we compute CER for aishell dataset.
     results_char = []
     for res in results:
-        results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        results_char.append(
+            (res[0], list("".join(res[1])), list("".join(res[2])))
+        )
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
             f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -440,7 +441,9 @@ def save_results(
         test_set_wers[key] = wer

     if enable_log:
-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info(
+            "Wrote detailed error stats to {}".format(errs_filename)
+        )

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -559,7 +562,9 @@ def main():
         eos_id=eos_id,
     )

-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )

     logging.info("Done!")

@@ -40,20 +40,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=84,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )

     parser.add_argument(
         "--avg",
         type=int,
         default=25,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        "'--epoch'. ",
     )

     parser.add_argument(
@@ -160,7 +157,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
@@ -46,29 +46,27 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
+        help="Path to the checkpoint. "
         "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        "icefall.checkpoint.save_checkpoint().",
     )

     parser.add_argument(
         "--tokens-file",
         type=str,
-        help="Path to tokens.txtUsed only when method is ctc-decoding",
+        help="Path to tokens.txt" "Used only when method is ctc-decoding",
     )

     parser.add_argument(
         "--words-file",
         type=str,
-        help="Path to words.txtUsed when method is NOT ctc-decoding",
+        help="Path to words.txt" "Used when method is NOT ctc-decoding",
     )

     parser.add_argument(
         "--HLG",
         type=str,
-        help="Path to HLG.pt.Used when method is NOT ctc-decoding",
+        help="Path to HLG.pt." "Used when method is NOT ctc-decoding",
     )

     parser.add_argument(
@@ -165,12 +163,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those supported by torchaudio.load(). "
         "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        "The sample rate has to be 16kHz.",
     )

     return parser
@@ -214,9 +210,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -277,7 +274,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)

-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )

     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -372,7 +371,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
@@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
-            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
             )
             cur_channels = block_dim

         self.layers = nn.Sequential(*layers)

-        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@ -16,8 +16,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from subsampling import Conv2dSubsampling
|
||||||
|
from subsampling import VggSubsampling
|
||||||
import torch
|
import torch
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
|
||||||
|
|
||||||
|
|
||||||
def test_conv2d_subsampling():
|
def test_conv2d_subsampling():
|
||||||
|
@ -382,7 +382,9 @@ def compute_loss(
|
|||||||
#
|
#
|
||||||
# See https://github.com/k2-fsa/icefall/issues/97
|
# See https://github.com/k2-fsa/icefall/issues/97
|
||||||
# for more details
|
# for more details
|
||||||
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
|
unsorted_token_ids = graph_compiler.texts_to_ids(
|
||||||
|
supervisions["text"]
|
||||||
|
)
|
||||||
att_loss = mmodel.decoder_forward(
|
att_loss = mmodel.decoder_forward(
|
||||||
encoder_memory,
|
encoder_memory,
|
||||||
memory_mask,
|
memory_mask,
|
||||||
@ -518,7 +520,9 @@ def train_one_epoch(
|
|||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
tot_loss.write_summary(
|
||||||
|
tb_writer, "train/tot_", params.batch_idx_train
|
||||||
|
)
|
||||||
|
|
||||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
@ -626,7 +630,9 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
cur_lr = optimizer._rate
|
cur_lr = optimizer._rate
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
tb_writer.add_scalar(
|
||||||
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||||
|
)
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
@ -149,7 +149,9 @@ class Transformer(nn.Module):
|
|||||||
norm=decoder_norm,
|
norm=decoder_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
|
self.decoder_output_layer = torch.nn.Linear(
|
||||||
|
d_model, self.decoder_num_class
|
||||||
|
)
|
||||||
|
|
||||||
self.decoder_criterion = LabelSmoothingLoss()
|
self.decoder_criterion = LabelSmoothingLoss()
|
||||||
else:
|
else:
|
||||||
@ -181,7 +183,9 @@ class Transformer(nn.Module):
|
|||||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||||
x = self.feat_batchnorm(x)
|
x = self.feat_batchnorm(x)
|
||||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||||
encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
|
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||||
|
x, supervision
|
||||||
|
)
|
||||||
x = self.ctc_output(encoder_memory)
|
x = self.ctc_output(encoder_memory)
|
||||||
return x, encoder_memory, memory_key_padding_mask
|
return x, encoder_memory, memory_key_padding_mask
|
||||||
|
|
||||||
@ -262,17 +266,23 @@ class Transformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
ys_in = [torch.tensor(y) for y in ys_in]
|
||||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
ys_in_pad = pad_sequence(
|
||||||
|
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||||
|
)
|
||||||
|
|
||||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||||
ys_out = [torch.tensor(y) for y in ys_out]
|
ys_out = [torch.tensor(y) for y in ys_out]
|
||||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
ys_out_pad = pad_sequence(
|
||||||
|
ys_out, batch_first=True, padding_value=float(-1)
|
||||||
|
)
|
||||||
|
|
||||||
device = memory.device
|
device = memory.device
|
||||||
ys_in_pad = ys_in_pad.to(device)
|
ys_in_pad = ys_in_pad.to(device)
|
||||||
ys_out_pad = ys_out_pad.to(device)
|
ys_out_pad = ys_out_pad.to(device)
|
||||||
|
|
||||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
|
||||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||||
# TODO: Use length information to create the decoder padding mask
|
# TODO: Use length information to create the decoder padding mask
|
||||||
@ -333,17 +343,23 @@ class Transformer(nn.Module):
|
|||||||
|
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
ys_in = [torch.tensor(y) for y in ys_in]
|
||||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
ys_in_pad = pad_sequence(
|
||||||
|
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||||
|
)
|
||||||
|
|
||||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||||
ys_out = [torch.tensor(y) for y in ys_out]
|
ys_out = [torch.tensor(y) for y in ys_out]
|
||||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
ys_out_pad = pad_sequence(
|
||||||
|
ys_out, batch_first=True, padding_value=float(-1)
|
||||||
|
)
|
||||||
|
|
||||||
device = memory.device
|
device = memory.device
|
||||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||||
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
||||||
|
|
||||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
|
||||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||||
# TODO: Use length information to create the decoder padding mask
|
# TODO: Use length information to create the decoder padding mask
|
||||||
@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
|
|||||||
elif activation == "gelu":
|
elif activation == "gelu":
|
||||||
return nn.functional.gelu
|
return nn.functional.gelu
|
||||||
|
|
||||||
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
raise RuntimeError(
|
||||||
|
"activation should be relu/gelu, not {}".format(activation)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PositionalEncoding(nn.Module):
|
class PositionalEncoding(nn.Module):
|
||||||
@ -818,7 +836,9 @@ def encoder_padding_mask(
|
|||||||
1,
|
1,
|
||||||
).to(torch.int32)
|
).to(torch.int32)
|
||||||
|
|
||||||
lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
|
lengths = [
|
||||||
|
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
|
||||||
|
]
|
||||||
for idx in range(supervision_segments.size(0)):
|
for idx in range(supervision_segments.size(0)):
|
||||||
# Note: TorchScript doesn't allow to unpack tensors as tuples
|
# Note: TorchScript doesn't allow to unpack tensors as tuples
|
||||||
sequence_idx = supervision_segments[idx, 0].item()
|
sequence_idx = supervision_segments[idx, 0].item()
|
||||||
@ -839,7 +859,9 @@ def encoder_padding_mask(
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
|
def decoder_padding_mask(
|
||||||
|
ys_pad: torch.Tensor, ignore_id: int = -1
|
||||||
|
) -> torch.Tensor:
|
||||||
"""Generate a length mask for input.
|
"""Generate a length mask for input.
|
||||||
|
|
||||||
The masked position are filled with True,
|
The masked position are filled with True,
|
||||||
|
@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(ConformerEncoderLayer, self).__init__()
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
|
self.self_attn = RelPositionMultiheadAttention(
|
||||||
|
d_model, nhead, dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||||
|
|
||||||
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
|
self.norm_ff_macaron = nn.LayerNorm(
|
||||||
|
d_model
|
||||||
|
) # for the macaron style FNN module
|
||||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||||
|
|
||||||
self.ff_scale = 0.5
|
self.ff_scale = 0.5
|
||||||
|
|
||||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||||
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
|
self.norm_final = nn.LayerNorm(
|
||||||
|
d_model
|
||||||
|
) # for the final output of the block
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
residual = src
|
residual = src
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
src = self.norm_ff_macaron(src)
|
src = self.norm_ff_macaron(src)
|
||||||
src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
|
src = residual + self.ff_scale * self.dropout(
|
||||||
|
self.feed_forward_macaron(src)
|
||||||
|
)
|
||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
src = self.norm_ff_macaron(src)
|
src = self.norm_ff_macaron(src)
|
||||||
|
|
||||||
@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
def __init__(
|
||||||
|
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||||
|
) -> None:
|
||||||
"""Construct an PositionalEncoding object."""
|
"""Construct an PositionalEncoding object."""
|
||||||
super(RelPositionalEncoding, self).__init__()
|
super(RelPositionalEncoding, self).__init__()
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
|
x.device
|
||||||
|
):
|
||||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||||
return
|
return
|
||||||
# Suppose `i` means to the position of query vector and `j` means the
|
# Suppose `i` means to the position of query vector and `j` means the
|
||||||
@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
if torch.equal(query, key) and torch.equal(key, value):
|
if torch.equal(query, key) and torch.equal(key, value):
|
||||||
# self-attention
|
# self-attention
|
||||||
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
|
q, k, v = nn.functional.linear(
|
||||||
3, dim=-1
|
query, in_proj_weight, in_proj_bias
|
||||||
)
|
).chunk(3, dim=-1)
|
||||||
|
|
||||||
elif torch.equal(key, value):
|
elif torch.equal(key, value):
|
||||||
# encoder-decoder attention
|
# encoder-decoder attention
|
||||||
@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
if attn_mask.dim() == 2:
|
if attn_mask.dim() == 2:
|
||||||
attn_mask = attn_mask.unsqueeze(0)
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||||
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
raise RuntimeError(
|
||||||
|
"The size of the 2D attn_mask is not correct."
|
||||||
|
)
|
||||||
elif attn_mask.dim() == 3:
|
elif attn_mask.dim() == 3:
|
||||||
if list(attn_mask.size()) != [
|
if list(attn_mask.size()) != [
|
||||||
bsz * num_heads,
|
bsz * num_heads,
|
||||||
query.size(0),
|
query.size(0),
|
||||||
key.size(0),
|
key.size(0),
|
||||||
]:
|
]:
|
||||||
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
raise RuntimeError(
|
||||||
|
"The size of the 3D attn_mask is not correct."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
"attn_mask's dimension {} is not supported".format(
|
||||||
|
attn_mask.dim()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# attn_mask's dim is 3 now.
|
# attn_mask's dim is 3 now.
|
||||||
|
|
||||||
# convert ByteTensor key_padding_mask to bool
|
# convert ByteTensor key_padding_mask to bool
|
||||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
if (
|
||||||
|
key_padding_mask is not None
|
||||||
|
and key_padding_mask.dtype == torch.uint8
|
||||||
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor"
|
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||||
" instead."
|
|
||||||
)
|
)
|
||||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||||
|
|
||||||
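Beyond the reflowed condition, the block above documents a real PyTorch deprecation: byte (uint8) key padding masks should be bool. A small sketch of the conversion, assuming True marks padded positions (normalize_key_padding_mask is a hypothetical helper, not icefall code):

    import warnings

    import torch

    def normalize_key_padding_mask(mask: torch.Tensor) -> torch.Tensor:
        # bool masks are what the attention ops expect; uint8 masks
        # still work but trigger a deprecation warning, so convert.
        if mask.dtype == torch.uint8:
            warnings.warn(
                "Byte tensor for key_padding_mask is deprecated. "
                "Use bool tensor instead."
            )
            mask = mask.to(torch.bool)
        return mask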
@@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)

         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)

-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )

         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

         if need_weights:
             # average attention weights over heads
@@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):

     """

-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
@@ -59,19 +59,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
        type=int,
        default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. "
-        ),
+        "'--epoch'. ",
    )

    parser.add_argument(
@@ -416,7 +413,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results


@@ -444,7 +443,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -452,7 +453,9 @@ def save_results(
         test_set_wers[key] = wer

         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -547,7 +550,9 @@ def main():

     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return

     model.to(device)
@@ -576,7 +581,9 @@ def main():
             eos_id=eos_id,
         )

-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )

     logging.info("Done!")

@@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
-            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
             )
             cur_channels = block_dim

         self.layers = nn.Sequential(*layers)

-        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -511,7 +511,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -623,7 +625,9 @@ def run(rank, world_size, args):

         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

         if rank == 0:
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )

-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )

             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,7 +183,9 @@ class Transformer(nn.Module):
         x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
         x = self.feat_batchnorm(x)
         x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

@@ -262,17 +266,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -333,17 +343,23 @@ class Transformer(nn.Module):

         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )

         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)

-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )

         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu

-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )


 class PositionalEncoding(nn.Module):
@@ -818,7 +836,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)

-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -839,7 +859,9 @@ def encoder_padding_mask(
     return mask


-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.

     The masked position are filled with True,
@@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
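The hunk above rewraps a lhotse idiom used by these fbank scripts: the training CutSet is tripled by adding 0.9x and 1.1x speed-perturbed copies before features are computed. As a standalone sketch (assuming cut_set is a lhotse.CutSet; the helper name is illustrative):

    from lhotse import CutSet

    def with_speed_perturbation(cut_set: CutSet) -> CutSet:
        # Combine the original cuts with slowed-down and sped-up copies,
        # tripling the amount of training data.
        return (
            cut_set
            + cut_set.perturb_speed(0.9)
            + cut_set.perturb_speed(1.1)
        )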
@@ -114,7 +116,9 @@ def get_args():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)

@@ -83,7 +83,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -109,7 +111,9 @@ def get_args():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)

@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
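The list comprehension being rewrapped here maps token strings to ids, substituting the id of "<unk>" for out-of-vocabulary tokens. A self-contained sketch with a toy table (both the helper and the table are illustrative, not icefall code):

    from typing import Dict, List

    def map_pieces(pieces: List[str], token2id: Dict[str, int]) -> List[int]:
        # Tokens missing from the table fall back to the "<unk>" id.
        return [
            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
        ]

    token2id = {"<unk>": 1, "HE": 2, "LLO": 3}
    assert map_pieces(["HE", "LLO", "XYZ"], token2id) == [2, 3, 1]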
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False


-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.

     Args:
@@ -317,7 +317,9 @@ def lexicon_to_fst(

 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()


@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")

-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
@@ -76,7 +76,11 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -114,11 +118,9 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'"
-        ),
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
@@ -186,7 +188,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,7 +249,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -258,7 +263,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -302,7 +310,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -375,7 +387,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results


@@ -401,7 +415,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -412,7 +428,8 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -456,7 +473,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -485,7 +504,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
@@ -50,7 +50,11 @@ from pathlib import Path
 import torch
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import str2bool

@@ -83,11 +87,9 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
+        help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'"
-        ),
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
@@ -118,7 +120,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     add_model_arguments(parser)
@@ -154,7 +157,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
@@ -187,7 +191,9 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        filename = (
+            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        )
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -195,14 +201,17 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir
+            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
+        help="Path to the checkpoint. "
         "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint()."
-        ),
+        "icefall.checkpoint.save_checkpoint().",
     )

     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those supported by torchaudio.load(). "
         "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz."
-        ),
+        "The sample rate has to be 16kHz.",
     )

     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )

     add_model_arguments(parser)
@@ -201,9 +196,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -260,9 +256,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)

-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )

-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )

     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -310,7 +310,9 @@ def main():
                 beam=params.beam_size,
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.method}"
+            )
         hyp_list.append(hyp)

     hyps = []
@@ -327,7 +329,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
@@ -49,6 +49,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
@@ -74,7 +75,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -200,7 +203,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need "
+        "to be changed.",
     )

     parser.add_argument(
@@ -223,45 +227,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )

     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )

     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )

     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
    )

     parser.add_argument(
@@ -560,7 +561,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -588,16 +593,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

         # Note: We use reduction=sum while computing the loss.
         info["loss"] = loss.detach().cpu().item()
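The conditional expression being reflowed above is a warmup schedule for the pruned RNN-T loss: 0.0 before warmup reaches 1.0, 0.1 strictly between 1.0 and 2.0, and 1.0 otherwise (exactly 1.0 falls through to 1.0). The same logic as a plain function, for clarity only (the function name is illustrative):

    def pruned_loss_schedule(warmup: float) -> float:
        # Keep the pruned loss from overwhelming the simple loss
        # before the model has learned a rough alignment.
        if warmup < 1.0:
            return 0.0
        if 1.0 < warmup < 2.0:
            return 0.1
        return 1.0

    assert pruned_loss_schedule(0.5) == 0.0
    assert pruned_loss_schedule(1.5) == 0.1
    assert pruned_loss_schedule(2.5) == 1.0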
@@ -713,7 +725,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise

         if params.print_diagnostics and batch_idx == 5:
@@ -1015,7 +1029,9 @@ def scan_pessimistic_batches_for_oom(
                 f"Failing criterion: {criterion} "
                 f"(={crit_values[criterion]}) ..."
             )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise

@@ -121,24 +121,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )

     parser.add_argument(
@@ -206,7 +202,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -266,7 +263,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -278,7 +277,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -322,7 +324,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -395,7 +401,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results


@@ -421,7 +429,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -432,7 +442,8 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -477,7 +488,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -505,12 +518,13 @@ def main():

     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -537,12 +551,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
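Aside: both hunks above rewrap the same idiom: find_checkpoints() returns checkpoint paths ordered from most recent to oldest, and the slice keeps the params.avg (or params.avg + 1) newest ones for model averaging. A minimal sketch of that selection; the glob pattern and sorting key are assumptions for illustration, not icefall's actual implementation:

    import re
    from pathlib import Path
    from typing import List


    def latest_checkpoints(exp_dir: Path, avg: int) -> List[str]:
        # Collect checkpoint-NNN.pt files, newest (highest NNN) first,
        # then keep the `avg` most recent ones.
        pattern = re.compile(r"checkpoint-(\d+)\.pt$")
        names = [str(p) for p in exp_dir.glob("checkpoint-*.pt")]
        names.sort(key=lambda n: int(pattern.search(n).group(1)), reverse=True)
        return names[:avg]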
@@ -571,7 +586,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)

@@ -88,24 +88,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )

     parser.add_argument(
@@ -136,7 +132,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     add_model_arguments(parser)
@@ -169,12 +166,13 @@ def main():

     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -197,12 +195,13 @@ def main():
         model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -230,7 +229,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -253,7 +252,9 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        filename = (
+            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        )
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -261,14 +262,17 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir
+            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -84,7 +84,9 @@ class Transducer(nn.Module):
         self.decoder_datatang = decoder_datatang
         self.joiner_datatang = joiner_datatang

-        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

         if decoder_datatang is not None:
@@ -177,7 +179,9 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)

         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
         boundary[:, 2] = y_lens
         boundary[:, 3] = encoder_out_lens

@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )

     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )

     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )

     add_model_arguments(parser)
@@ -201,9 +196,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
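Aside: the assert above is only rewrapped, not changed. In use, the contract of read_sound_files() is simply "expected sample rate or fail fast"; a hedged usage sketch (the file name is hypothetical):

    import torchaudio

    wave, sample_rate = torchaudio.load("test.wav")
    # Fail fast on an unexpected sample rate, mirroring the assert above.
    assert sample_rate == 16000, f"expected sample rate: 16000. Given: {sample_rate}"
    mono = wave[0]  # keep only the first channel, as the recipe does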
@@ -261,9 +257,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)

-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )

-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )

     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -311,7 +311,9 @@ def main():
                 beam=params.beam_size,
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.method}"
+            )
         hyp_list.append(hyp)

     hyps = []
@@ -328,7 +330,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -96,7 +96,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -222,7 +224,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need "
+        "to be changed.",
     )

     parser.add_argument(
@@ -245,45 +248,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )

     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )

     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )

     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )

     parser.add_argument(
@@ -635,7 +635,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
       values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -666,16 +670,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -813,7 +824,9 @@ def train_one_epoch(
             )
             # summary stats
             if datatang_train_dl is not None:
-                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+                tot_loss = (
+                    tot_loss * (1 - 1 / params.reset_interval)
+                ) + loss_info

             if aishell:
                 aishell_tot_loss = (
@@ -834,7 +847,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise

         if params.print_diagnostics and batch_idx == 5:
@@ -877,7 +892,9 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             if datatang_train_dl is not None:
                 datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
-                tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                tot_loss_str = (
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                )
             else:
                 tot_loss_str = ""
                 datatang_str = ""
@@ -1059,7 +1076,9 @@ def run(rank, world_size, args):
     train_cuts = filter_short_and_long_utterances(train_cuts)

     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None

@@ -1074,7 +1093,9 @@ def run(rank, world_size, args):
     if params.datatang_prob > 0:
         datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
         train_datatang_cuts = datatang.train_cuts()
-        train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
         train_datatang_cuts = train_datatang_cuts.repeat(times=None)
         datatang_train_dl = asr_datamodule.train_dataloaders(
             train_datatang_cuts,
@@ -1228,7 +1249,9 @@ def scan_pessimistic_batches_for_oom(
                 f"Failing criterion: {criterion} "
                 f"(={crit_values[criterion]}) ..."
             )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise

@@ -64,12 +64,10 @@ class AishellAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -81,74 +79,59 @@ class AishellAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
        )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
        )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -160,18 +143,17 @@ class AishellAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )

         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )

         group.add_argument(
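Aside: the boolean flags above use type=str2bool rather than type=bool, because argparse would treat any non-empty string (including "false") as True. A minimal sketch of such a helper, assuming the usual icefall-style semantics:

    import argparse


    def str2bool(v: str) -> bool:
        # Accept common command-line spellings of true/false.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")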
@@ -185,40 +167,40 @@ class AishellAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )

         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )

     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")

         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -233,7 +215,9 @@ class AishellAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -276,7 +260,9 @@ class AishellAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -322,7 +308,9 @@ class AishellAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -378,9 +366,13 @@ class AishellAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
+        )

     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
+        )
@@ -49,19 +49,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -268,7 +265,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results


@@ -290,7 +289,9 @@ def save_results(
         # We compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
             test_set_wers[key] = wer
@@ -334,7 +335,9 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

@@ -359,7 +362,9 @@ def main():

     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )

     model.to(device)
     model.eval()
@@ -387,7 +392,9 @@ def main():
         lexicon=lexicon,
     )

-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )

     logging.info("Done!")

@@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
+            [
+                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
+                for _ in range(5)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
         )

@@ -41,11 +41,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )

     parser.add_argument(
@@ -55,7 +53,9 @@ def get_parser():
         help="Path to words.txt",
     )

-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )

     parser.add_argument(
         "--method",
@@ -71,12 +71,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )

     return parser
@@ -114,9 +112,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -174,7 +173,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)

-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is [N, C, T]

     with torch.no_grad():
@@ -218,7 +219,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -49,7 +49,12 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)


 def get_parser():

@@ -47,9 +47,9 @@ def greedy_search(

     device = model.device

-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)

     decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -81,9 +81,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)

             decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -157,7 +157,9 @@ class HypothesisList(object):

         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)

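Aside: the rewrapped max(...) above picks the best hypothesis; with length_norm the summed log-probability is divided by the hypothesis length so longer hypotheses are not unfairly penalized. A toy sketch of the selection (the Hyp record is illustrative, not icefall's Hypothesis class):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Hyp:
        ys: List[int]  # decoded token ids
        log_prob: float  # total (summed) log-probability


    hyps = [Hyp([1, 2], -3.0), Hyp([1, 2, 3, 4], -4.8)]
    best_raw = max(hyps, key=lambda h: h.log_prob)  # -3.0: the short one wins
    best_norm = max(hyps, key=lambda h: h.log_prob / len(h.ys))  # -1.2: the long one wins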
@@ -244,9 +246,9 @@ def beam_search(

     device = model.device

-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)

     decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -155,7 +155,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -173,14 +175,18 @@ class ConformerEncoderLayer(nn.Module):

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         self.ff_scale = 0.5

         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

@@ -214,7 +220,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

@@ -333,7 +341,9 @@ class RelPositionalEncoding(torch.nn.Module):

     """

-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -349,7 +359,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(1) >= x.size(1) * 2 - 1:
             # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                x.device
+            ):
                 self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -619,9 +631,9 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
|
|||||||
if attn_mask.dim() == 2:
|
if attn_mask.dim() == 2:
|
||||||
attn_mask = attn_mask.unsqueeze(0)
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||||
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
raise RuntimeError(
|
||||||
|
"The size of the 2D attn_mask is not correct."
|
||||||
|
)
|
||||||
elif attn_mask.dim() == 3:
|
elif attn_mask.dim() == 3:
|
||||||
if list(attn_mask.size()) != [
|
if list(attn_mask.size()) != [
|
||||||
bsz * num_heads,
|
bsz * num_heads,
|
||||||
query.size(0),
|
query.size(0),
|
||||||
key.size(0),
|
key.size(0),
|
||||||
]:
|
]:
|
||||||
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
raise RuntimeError(
|
||||||
|
"The size of the 3D attn_mask is not correct."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
"attn_mask's dimension {} is not supported".format(
|
||||||
|
attn_mask.dim()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# attn_mask's dim is 3 now.
|
# attn_mask's dim is 3 now.
|
||||||
|
|
||||||
# convert ByteTensor key_padding_mask to bool
|
# convert ByteTensor key_padding_mask to bool
|
||||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
if (
|
||||||
|
key_padding_mask is not None
|
||||||
|
and key_padding_mask.dtype == torch.uint8
|
||||||
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor"
|
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||||
" instead."
|
|
||||||
)
|
)
|
||||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||||
|
|
||||||
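The last change in this hunk re-wraps a deprecation shim: a uint8 (byte) key_padding_mask still works but is converted to bool after a warning. Sketched in isolation (the helper name is mine):

import warnings

import torch

def normalize_key_padding_mask(mask: torch.Tensor) -> torch.Tensor:
    if mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask is deprecated. "
            "Use bool tensor instead."
        )
        mask = mask.to(torch.bool)
    return mask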
@@ -744,7 +764,9 @@ class RelPositionMultiheadAttention(nn.Module):
 # first compute matrix a and matrix c
 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
 k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+matrix_ac = torch.matmul(
+q_with_bias_u, k
+)  # (batch, head, time1, time2)

 # compute matrix b and matrix d
 matrix_bd = torch.matmul(
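matrix_ac here is the content-content score term of Transformer-XL (Section 3.3); k has already been permuted so that a single matmul yields all (time1, time2) scores per head. A shape-only sketch with assumed sizes:

import torch

B, H, T, D = 2, 4, 6, 16                     # batch, heads, time, d_k (assumed)
q_with_bias_u = torch.randn(B, H, T, D)
k = torch.randn(B, H, D, T)                  # permuted to (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k)   # (batch, head, time1, time2)
assert matrix_ac.shape == (B, H, T, T)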
@@ -756,7 +778,9 @@ class RelPositionMultiheadAttention(nn.Module):
 matrix_ac + matrix_bd
 ) * scaling  # (batch, head, time1, time2)

-attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+attn_output_weights = attn_output_weights.view(
+bsz * num_heads, tgt_len, -1
+)

 assert list(attn_output_weights.size()) == [
 bsz * num_heads,
@@ -790,9 +814,13 @@ class RelPositionMultiheadAttention(nn.Module):
 attn_output = torch.bmm(attn_output_weights, v)
 assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
 attn_output = (
-attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+attn_output.transpose(0, 1)
+.contiguous()
+.view(tgt_len, bsz, embed_dim)
+)
+attn_output = nn.functional.linear(
+attn_output, out_proj_weight, out_proj_bias
 )
-attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

 if need_weights:
 # average attention weights over heads
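The reshape being re-broken in the second hunk merges the attention heads back into the embedding dimension before the output projection. A sketch with assumed sizes:

import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8   # assumed sizes
embed_dim = num_heads * head_dim
attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim)
attn_output = (
    attn_output.transpose(0, 1)
    .contiguous()
    .view(tgt_len, bsz, embed_dim)
)
assert attn_output.shape == (tgt_len, bsz, embed_dim)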
@@ -815,7 +843,9 @@ class ConvolutionModule(nn.Module):

 """

-def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+def __init__(
+self, channels: int, kernel_size: int, bias: bool = True
+) -> None:
 """Construct an ConvolutionModule object."""
 super(ConvolutionModule, self).__init__()
 # kernerl_size should be a odd number for 'SAME' padding
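On the comment above (which misspells kernel_size): for a stride-1 convolution, 'SAME' output length needs padding = (kernel_size - 1) // 2 on each side, which only works out symmetrically when kernel_size is odd. A one-line check:

kernel_size = 31  # a typical Conformer depthwise kernel (assumed)
assert kernel_size % 2 == 1
padding = (kernel_size - 1) // 2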
@@ -52,19 +52,16 @@ def get_parser():
 "--epoch",
 type=int,
 default=30,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )
 parser.add_argument(
 "--avg",
 type=int,
 default=10,
-help=(
-"Number of checkpoints to average. Automatically select "
+help="Number of checkpoints to average. Automatically select "
 "consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+"'--epoch'. ",
 )

 parser.add_argument(
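Every help-string hunk in this stretch relies on Python's implicit concatenation of adjacent string literals, so the two layouts produce byte-identical help text; note that the missing space after "decoding." survives either way:

assert (
    "It specifies the checkpoint to use for decoding."
    "Note: Epoch counts from 0."
) == "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."

The revert therefore changes wrapping only, never the strings argparse ends up showing.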
@@ -102,7 +99,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )
 parser.add_argument(
 "--max-sym-per-frame",
@@ -229,7 +227,9 @@ def decode_one_batch(
 supervisions = batch["supervisions"]
 feature_lens = supervisions["num_frames"].to(device)

-encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=feature, x_lens=feature_lens
+)
 hyps = []
 batch_size = encoder_out.size(0)

@@ -248,7 +248,9 @@ def decode_one_batch(
 model=model, encoder_out=encoder_out_i, beam=params.beam_size
 )
 else:
-raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+raise ValueError(
+f"Unsupported decoding method: {params.decoding_method}"
+)
 hyps.append([lexicon.token_table[i] for i in hyp])

 if params.decoding_method == "greedy_search":
@@ -317,7 +319,9 @@ def decode_dataset(
 if batch_idx % log_interval == 0:
 batch_str = f"{batch_idx}/{num_batches}"

-logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+logging.info(
+f"batch {batch_str}, cuts processed until now is {num_cuts}"
+)
 return results


@@ -342,7 +346,9 @@ def save_results(
 # we compute CER for aishell dataset.
 results_char = []
 for res in results:
-results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+results_char.append(
+(res[0], list("".join(res[1])), list("".join(res[2])))
+)
 with open(errs_filename, "w") as f:
 wer = write_error_stats(
 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -353,7 +359,8 @@ def save_results(

 test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
 errs_info = (
-params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+params.res_dir
+/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
 )
 with open(errs_info, "w") as f:
 print("settings\tCER", file=f)
@@ -423,7 +430,9 @@ def main():

 if params.export:
 logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+torch.save(
+{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+)
 return

 model.to(device)
@@ -86,7 +86,9 @@ class Decoder(nn.Module):
 if self.context_size > 1:
 embedding_out = embedding_out.permute(0, 2, 1)
 if need_pad is True:
-embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+embedding_out = F.pad(
+embedding_out, pad=(self.context_size - 1, 0)
+)
 else:
 # During inference time, there is no need to do extra padding
 # as we only need one output
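The F.pad call above left-pads the time axis so the decoder's causal convolution sees context_size - 1 frames of history during training; at inference no padding is needed because only one output frame is requested. In isolation (shapes assumed):

import torch
import torch.nn.functional as F

context_size = 2                          # bigram decoder context
embedding_out = torch.randn(3, 16, 10)    # (batch, channels, time)
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
assert padded.shape == (3, 16, 11)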
@@ -69,20 +69,17 @@ def get_parser():
 "--epoch",
 type=int,
 default=20,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )

 parser.add_argument(
 "--avg",
 type=int,
 default=10,
-help=(
-"Number of checkpoints to average. Automatically select "
+help="Number of checkpoints to average. Automatically select "
 "consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+"'--epoch'. ",
 )

 parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 return parser
@@ -245,7 +243,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()
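The formatter hunks are pure re-wraps; either spelling configures the same log layout:

import logging

formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info("logs now carry timestamp, level, file and line")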
@@ -103,7 +103,9 @@ class Transducer(nn.Module):
 y_padded = y.pad(mode="constant", padding_value=0)

 y_padded = y_padded.to(torch.int64)
-boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+boundary = torch.zeros(
+(x.size(0), 4), dtype=torch.int64, device=x.device
+)
 boundary[:, 2] = y_lens
 boundary[:, 3] = x_lens

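The boundary tensor above follows, to my reading, the per-sequence [begin_symbol, begin_frame, end_symbol, end_frame] layout that k2's pruned RNN-T loss expects, so columns 2 and 3 hold the label and frame lengths. A sketch with assumed lengths:

import torch

x = torch.randn(2, 50, 80)          # (batch, frames, features), assumed
x_lens = torch.tensor([50, 42])
y_lens = torch.tensor([7, 5])
boundary = torch.zeros(
    (x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens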
@@ -73,11 +73,9 @@ def get_parser():
 "--checkpoint",
 type=str,
 required=True,
-help=(
-"Path to the checkpoint. "
+help="Path to the checkpoint. "
 "The checkpoint is assumed to be saved by "
-"icefall.checkpoint.save_checkpoint()."
-),
+"icefall.checkpoint.save_checkpoint().",
 )

 parser.add_argument(
@@ -102,12 +100,10 @@ def get_parser():
 "sound_files",
 type=str,
 nargs="+",
-help=(
-"The input sound file(s) to transcribe. "
+help="The input sound file(s) to transcribe. "
 "Supported formats are those supported by torchaudio.load(). "
 "For example, wav and flac are supported. "
-"The sample rate has to be 16kHz."
-),
+"The sample rate has to be 16kHz.",
 )

 parser.add_argument(
@@ -121,7 +117,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )
 parser.add_argument(
 "--max-sym-per-frame",
@@ -214,9 +211,10 @@ def read_sound_files(
 ans = []
 for f in filenames:
 wave, sample_rate = torchaudio.load(f)
-assert (
-sample_rate == expected_sample_rate
-), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+assert sample_rate == expected_sample_rate, (
+f"expected sample rate: {expected_sample_rate}. "
+f"Given: {sample_rate}"
+)
 # We use only the first channel
 ans.append(wave[0])
 return ans
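Reassembled from the context lines around this hunk, read_sound_files is roughly the following; only the assert's message formatting changes in the revert (the 16 kHz default is an assumption):

from typing import List

import torch
import torchaudio

def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        ans.append(wave[0])  # use only the first channel
    return ans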
@@ -275,7 +273,9 @@ def main():
 features = fbank(waves)
 feature_lengths = [f.size(0) for f in features]

-features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+features = pad_sequence(
+features, batch_first=True, padding_value=math.log(1e-10)
+)

 feature_lengths = torch.tensor(feature_lengths, device=device)

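padding_value=math.log(1e-10) pads the log-mel features with the log of a near-zero energy, i.e. approximate silence, rather than zeros (which would mean energy 1). A shape sketch:

import math

import torch
from torch.nn.utils.rnn import pad_sequence

features = [torch.randn(7, 80), torch.randn(5, 80)]   # assumed fbank shapes
feature_lengths = torch.tensor([f.size(0) for f in features])
features = pad_sequence(
    features, batch_first=True, padding_value=math.log(1e-10)
)
assert features.shape == (2, 7, 80)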
@@ -319,7 +319,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()
@@ -126,7 +126,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
@@ -388,7 +389,9 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
-info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+info["frames"] = (
+(feature_lens // params.subsampling_factor).sum().item()
+)

 # Note: We use reduction=sum while computing the loss.
 info["loss"] = loss.detach().cpu().item()
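The metric being re-wrapped counts encoder output frames after subsampling; with an assumed subsampling factor of 4:

import torch

feature_lens = torch.tensor([100, 64])
subsampling_factor = 4          # the usual Conformer setting (assumed)
frames = (feature_lens // subsampling_factor).sum().item()
assert frames == 25 + 16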
@@ -501,7 +504,9 @@ def train_one_epoch(
 loss_info.write_summary(
 tb_writer, "train/current_", params.batch_idx_train
 )
-tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+tot_loss.write_summary(
+tb_writer, "train/tot_", params.batch_idx_train
+)

 if batch_idx > 0 and batch_idx % params.valid_interval == 0:
 logging.info("Computing validation loss")
@@ -620,7 +625,9 @@ def run(rank, world_size, args):

 cur_lr = optimizer._rate
 if tb_writer is not None:
-tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+tb_writer.add_scalar(
+"train/learning_rate", cur_lr, params.batch_idx_train
+)
 tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

 if rank == 0:
@@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
 elif activation == "gelu":
 return nn.functional.gelu

-raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+raise RuntimeError(
+"activation should be relu/gelu, not {}".format(activation)
+)


 class PositionalEncoding(nn.Module):
@@ -29,7 +29,10 @@ from lhotse.dataset import (
 K2SpeechRecognitionDataset,
 SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
+from lhotse.dataset.input_strategies import (
+OnTheFlyFeatures,
+PrecomputedFeatures,
+)
 from torch.utils.data import DataLoader

 from icefall.utils import str2bool
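For context, the full lookup around the re-wrapped raise is, reconstructed from the visible lines (the relu branch is implied, not shown in the hunk):

import torch.nn as nn

def _get_activation_fn(activation: str):
    if activation == "relu":
        return nn.functional.relu
    elif activation == "gelu":
        return nn.functional.gelu
    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )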
@@ -43,69 +46,59 @@ class AsrDataModule:
 def add_arguments(cls, parser: argparse.ArgumentParser):
 group = parser.add_argument_group(
 title="ASR data related options",
-description=(
-"These options are used for the preparation of "
+description="These options are used for the preparation of "
 "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
 "effective batch sizes, sampling strategies, applied data "
-"augmentations, etc."
-),
+"augmentations, etc.",
 )

 group.add_argument(
 "--max-duration",
 type=int,
 default=200.0,
-help=(
-"Maximum pooled recordings duration (seconds) in a "
-"single batch. You can reduce it if it causes CUDA OOM."
-),
+help="Maximum pooled recordings duration (seconds) in a "
+"single batch. You can reduce it if it causes CUDA OOM.",
 )

 group.add_argument(
 "--bucketing-sampler",
 type=str2bool,
 default=True,
-help=(
-"When enabled, the batches will come from buckets of "
-"similar duration (saves padding frames)."
-),
+help="When enabled, the batches will come from buckets of "
+"similar duration (saves padding frames).",
 )

 group.add_argument(
 "--num-buckets",
 type=int,
 default=30,
-help=(
-"The number of buckets for the DynamicBucketingSampler "
-"(you might want to increase it for larger datasets)."
-),
+help="The number of buckets for the DynamicBucketingSampler "
+"(you might want to increase it for larger datasets).",
 )

 group.add_argument(
 "--shuffle",
 type=str2bool,
 default=True,
-help=(
-"When enabled (=default), the examples will be shuffled for each epoch."
-),
+help="When enabled (=default), the examples will be "
+"shuffled for each epoch.",
 )

 group.add_argument(
 "--return-cuts",
 type=str2bool,
 default=True,
-help=(
-"When enabled, each batch will have the "
+help="When enabled, each batch will have the "
 "field: batch['supervisions']['cut'] with the cuts that "
-"were used to construct it."
-),
+"were used to construct it.",
 )

 group.add_argument(
 "--num-workers",
 type=int,
 default=2,
-help="The number of training dataloader workers that collect the batches.",
+help="The number of training dataloader workers that "
+"collect the batches.",
 )

 group.add_argument(
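Every add_argument hunk in this block follows the same pattern; a runnable condensation showing only --max-duration (str2bool and the remaining options elided):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(
    title="ASR data related options",
    description="These options are used for the preparation of "
    "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
    "effective batch sizes, sampling strategies, applied data "
    "augmentations, etc.",
)
group.add_argument(
    "--max-duration",
    type=int,
    default=200.0,
    help="Maximum pooled recordings duration (seconds) in a "
    "single batch. You can reduce it if it causes CUDA OOM.",
)
args = parser.parse_args([])
assert args.max_duration == 200.0   # defaults are not coerced by type=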
@@ -119,22 +112,18 @@ class AsrDataModule:
 "--spec-aug-time-warp-factor",
 type=int,
 default=80,
-help=(
-"Used only when --enable-spec-aug is True. "
+help="Used only when --enable-spec-aug is True. "
 "It specifies the factor for time warping in SpecAugment. "
 "Larger values mean more warping. "
-"A value less than 1 means to disable time warp."
-),
+"A value less than 1 means to disable time warp.",
 )

 group.add_argument(
 "--enable-musan",
 type=str2bool,
 default=True,
-help=(
-"When enabled, select noise from MUSAN and mix it"
-"with training dataset. "
-),
+help="When enabled, select noise from MUSAN and mix it"
+"with training dataset. ",
 )

 group.add_argument(
@@ -148,11 +137,9 @@ class AsrDataModule:
 "--on-the-fly-feats",
 type=str2bool,
 default=False,
-help=(
-"When enabled, use on-the-fly cut mixing and feature "
+help="When enabled, use on-the-fly cut mixing and feature "
 "extraction. Will drop existing precomputed feature manifests "
-"if available. Used only in dev/test CutSet"
-),
+"if available. Used only in dev/test CutSet",
 )

 def train_dataloaders(
@@ -175,7 +162,9 @@ class AsrDataModule:
 if cuts_musan is not None:
 logging.info("Enable MUSAN")
 transforms.append(
-CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+CutMix(
+cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+)
 )
 else:
 logging.info("Disable MUSAN")
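The CutMix transform above mixes MUSAN noise into training cuts with probability 0.5 at an SNR drawn from 10-20 dB, preserving cut ids so batches stay traceable. A hedged sketch of that branch (the manifest path is assumed):

import logging

from lhotse import load_manifest
from lhotse.dataset import CutMix

def musan_transforms(musan_path: str = "data/fbank/musan_cuts.jsonl.gz"):
    transforms = []
    cuts_musan = load_manifest(musan_path)
    logging.info("Enable MUSAN")
    transforms.append(
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
    )
    return transforms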
@@ -184,7 +173,9 @@ class AsrDataModule:

 if self.args.enable_spec_aug:
 logging.info("Enable SpecAugment")
-logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+logging.info(
+f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+)
 # Set the value of num_frame_masks according to Lhotse's version.
 # In different Lhotse's versions, the default of num_frame_masks is
 # different.
@@ -261,7 +252,9 @@ class AsrDataModule:
 if self.args.on_the_fly_feats:
 validate = K2SpeechRecognitionDataset(
 cut_transforms=transforms,
-input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+input_strategy=OnTheFlyFeatures(
+Fbank(FbankConfig(num_mel_bins=80))
+),
 return_cuts=self.args.return_cuts,
 )
 else:
@@ -93,19 +93,16 @@ def get_parser():
 "--epoch",
 type=int,
 default=30,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )
 parser.add_argument(
 "--avg",
 type=int,
 default=10,
-help=(
-"Number of checkpoints to average. Automatically select "
+help="Number of checkpoints to average. Automatically select "
 "consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+"'--epoch'. ",
 )

 parser.add_argument(
@@ -173,7 +170,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
@@ -229,7 +227,9 @@ def decode_one_batch(
 supervisions = batch["supervisions"]
 feature_lens = supervisions["num_frames"].to(device)

-encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=feature, x_lens=feature_lens
+)

 if params.decoding_method == "fast_beam_search":
 hyp_tokens = fast_beam_search_one_best(
@@ -241,7 +241,10 @@ def decode_one_batch(
 max_contexts=params.max_contexts,
 max_states=params.max_states,
 )
-elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+elif (
+params.decoding_method == "greedy_search"
+and params.max_sym_per_frame == 1
+):
 hyp_tokens = greedy_search_batch(
 model=model,
 encoder_out=encoder_out,
@@ -285,7 +288,11 @@ def decode_one_batch(
 return {"greedy_search": hyps}
 elif params.decoding_method == "fast_beam_search":
 return {
-f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+(
+f"beam_{params.beam}_"
+f"max_contexts_{params.max_contexts}_"
+f"max_states_{params.max_states}"
+): hyps
 }
 else:
 return {f"beam_size_{params.beam_size}": hyps}
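The dictionary-key hunk re-breaks one long f-string into three; parenthesized adjacent f-strings concatenate to the identical key:

beam, max_contexts, max_states = 4, 4, 8   # typical fast_beam_search values
long_form = f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
split_form = (
    f"beam_{beam}_"
    f"max_contexts_{max_contexts}_"
    f"max_states_{max_states}"
)
assert long_form == split_form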
@@ -358,7 +365,9 @@ def decode_dataset(
 if batch_idx % log_interval == 0:
 batch_str = f"{batch_idx}/{num_batches}"

-logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+logging.info(
+f"batch {batch_str}, cuts processed until now is {num_cuts}"
+)
 return results


@@ -384,7 +393,9 @@ def save_results(
 # we compute CER for aishell dataset.
 results_char = []
 for res in results:
-results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+results_char.append(
+(res[0], list("".join(res[1])), list("".join(res[2])))
+)
 with open(errs_filename, "w") as f:
 wer = write_error_stats(
 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -395,7 +406,8 @@ def save_results(

 test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
 errs_info = (
-params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+params.res_dir
+/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
 )
 with open(errs_info, "w") as f:
 print("settings\tCER", file=f)
@@ -436,7 +448,9 @@ def main():
 params.suffix += f"-max-contexts-{params.max_contexts}"
 params.suffix += f"-max-states-{params.max_states}"
 elif "beam_search" in params.decoding_method:
-params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+params.suffix += (
+f"-{params.decoding_method}-beam-size-{params.beam_size}"
+)
 else:
 params.suffix += f"-context-{params.context_size}"
 params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -68,20 +68,17 @@ def get_parser():
 "--epoch",
 type=int,
 default=20,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )

 parser.add_argument(
 "--avg",
 type=int,
 default=10,
-help=(
-"Number of checkpoints to average. Automatically select "
+help="Number of checkpoints to average. Automatically select "
 "consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+"'--epoch'. ",
 )

 parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 return parser
@@ -243,7 +241,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()
@@ -87,11 +87,9 @@ def get_parser():
 "--checkpoint",
 type=str,
 required=True,
-help=(
-"Path to the checkpoint. "
+help="Path to the checkpoint. "
 "The checkpoint is assumed to be saved by "
-"icefall.checkpoint.save_checkpoint()."
-),
+"icefall.checkpoint.save_checkpoint().",
 )

 parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
 "sound_files",
 type=str,
 nargs="+",
-help=(
-"The input sound file(s) to transcribe. "
+help="The input sound file(s) to transcribe. "
 "Supported formats are those supported by torchaudio.load(). "
 "For example, wav and flac are supported. "
-"The sample rate has to be 16kHz."
-),
+"The sample rate has to be 16kHz.",
 )

 parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )
 parser.add_argument(
 "--max-sym-per-frame",
 type=int,
 default=1,
-help=(
-"Maximum number of symbols per frame. "
-"Use only when --method is greedy_search"
-),
+help="Maximum number of symbols per frame. "
+"Use only when --method is greedy_search",
 )

 return parser
@@ -199,9 +194,10 @@ def read_sound_files(
 ans = []
 for f in filenames:
 wave, sample_rate = torchaudio.load(f)
-assert (
-sample_rate == expected_sample_rate
-), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+assert sample_rate == expected_sample_rate, (
+f"expected sample rate: {expected_sample_rate}. "
+f"Given: {sample_rate}"
+)
 # We use only the first channel
 ans.append(wave[0])
 return ans
@@ -258,9 +254,13 @@ def main():
 feature_lens = [f.size(0) for f in features]
 feature_lens = torch.tensor(feature_lens, device=device)

-features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+features = pad_sequence(
+features, batch_first=True, padding_value=math.log(1e-10)
+)

-encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=features, x_lens=feature_lens
+)

 num_waves = encoder_out.size(0)
 hyp_list = []
@@ -308,7 +308,9 @@ def main():
 beam=params.beam_size,
 )
 else:
-raise ValueError(f"Unsupported decoding method: {params.method}")
+raise ValueError(
+f"Unsupported decoding method: {params.method}"
+)
 hyp_list.append(hyp)

 hyps = []
@@ -325,7 +327,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()
@@ -149,7 +149,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
@@ -167,7 +168,8 @@ def get_parser():
 "--datatang-prob",
 type=float,
 default=0.2,
-help="The probability to select a batch from the aidatatang_200zh dataset",
+help="The probability to select a batch from the "
+"aidatatang_200zh dataset",
 )

 return parser
@@ -447,7 +449,9 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
-info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+info["frames"] = (
+(feature_lens // params.subsampling_factor).sum().item()
+)

 # Note: We use reduction=sum while computing the loss.
 info["loss"] = loss.detach().cpu().item()
@@ -601,7 +605,9 @@ def train_one_epoch(
 f"train/current_{prefix}_",
 params.batch_idx_train,
 )
-tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+tot_loss.write_summary(
+tb_writer, "train/tot_", params.batch_idx_train
+)
 aishell_tot_loss.write_summary(
 tb_writer, "train/aishell_tot_", params.batch_idx_train
 )
@@ -729,7 +735,9 @@ def run(rank, world_size, args):
 train_datatang_cuts = train_datatang_cuts.repeat(times=None)

 if args.enable_musan:
-cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+cuts_musan = load_manifest(
+Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+)
 else:
 cuts_musan = None

@@ -768,7 +776,9 @@ def run(rank, world_size, args):

 cur_lr = optimizer._rate
 if tb_writer is not None:
-tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+tb_writer.add_scalar(
+"train/learning_rate", cur_lr, params.batch_idx_train
+)
 tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

 if rank == 0:
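The add_scalar re-wraps above all target the same TensorBoard calls; minimal usage (the log directory is assumed):

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="exp/tensorboard")
cur_lr, epoch, batch_idx_train = 5e-4, 1, 100
tb_writer.add_scalar("train/learning_rate", cur_lr, batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, batch_idx_train)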
@@ -94,19 +94,16 @@ def get_parser():
 "--epoch",
 type=int,
 default=30,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )
 parser.add_argument(
 "--avg",
 type=int,
 default=10,
-help=(
-"Number of checkpoints to average. Automatically select "
+help="Number of checkpoints to average. Automatically select "
 "consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+"'--epoch'. ",
 )

 parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
@@ -233,7 +231,9 @@ def decode_one_batch(
 supervisions = batch["supervisions"]
 feature_lens = supervisions["num_frames"].to(device)

-encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=feature, x_lens=feature_lens
+)

 if params.decoding_method == "fast_beam_search":
 hyp_tokens = fast_beam_search_one_best(
@@ -245,7 +245,10 @@ def decode_one_batch(
 max_contexts=params.max_contexts,
 max_states=params.max_states,
 )
-elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+elif (
+params.decoding_method == "greedy_search"
+and params.max_sym_per_frame == 1
+):
 hyp_tokens = greedy_search_batch(
 model=model,
 encoder_out=encoder_out,
@@ -289,7 +292,11 @@ def decode_one_batch(
 return {"greedy_search": hyps}
 elif params.decoding_method == "fast_beam_search":
 return {
-f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+(
+f"beam_{params.beam}_"
+f"max_contexts_{params.max_contexts}_"
+f"max_states_{params.max_states}"
+): hyps
 }
 else:
 return {f"beam_size_{params.beam_size}": hyps}
@@ -362,7 +369,9 @@ def decode_dataset(
 if batch_idx % log_interval == 0:
 batch_str = f"{batch_idx}/{num_batches}"

-logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+logging.info(
+f"batch {batch_str}, cuts processed until now is {num_cuts}"
+)
 return results


@@ -388,7 +397,9 @@ def save_results(
 # we compute CER for aishell dataset.
 results_char = []
 for res in results:
-results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+results_char.append(
+(res[0], list("".join(res[1])), list("".join(res[2])))
+)
 with open(errs_filename, "w") as f:
 wer = write_error_stats(
 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -399,7 +410,8 @@ def save_results(

 test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
 errs_info = (
-params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+params.res_dir
+/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
 )
 with open(errs_info, "w") as f:
 print("settings\tCER", file=f)
@@ -440,7 +452,9 @@ def main():
 params.suffix += f"-max-contexts-{params.max_contexts}"
 params.suffix += f"-max-states-{params.max_states}"
 elif "beam_search" in params.decoding_method:
-params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+params.suffix += (
+f"-{params.decoding_method}-beam-size-{params.beam_size}"
+)
 else:
 params.suffix += f"-context-{params.context_size}"
 params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -68,20 +68,17 @@ def get_parser():
 "--epoch",
 type=int,
 default=20,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )

 parser.add_argument(
 "--avg",
 type=int,
 default=10,
-help=(
-"Number of checkpoints to average. Automatically select "
+help="Number of checkpoints to average. Automatically select "
 "consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+"'--epoch'. ",
 )

 parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 return parser
@@ -243,7 +241,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()
@@ -87,11 +87,9 @@ def get_parser():
 "--checkpoint",
 type=str,
 required=True,
-help=(
-"Path to the checkpoint. "
+help="Path to the checkpoint. "
 "The checkpoint is assumed to be saved by "
-"icefall.checkpoint.save_checkpoint()."
-),
+"icefall.checkpoint.save_checkpoint().",
 )

 parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
 "sound_files",
 type=str,
 nargs="+",
-help=(
-"The input sound file(s) to transcribe. "
+help="The input sound file(s) to transcribe. "
 "Supported formats are those supported by torchaudio.load(). "
 "For example, wav and flac are supported. "
-"The sample rate has to be 16kHz."
-),
+"The sample rate has to be 16kHz.",
 )

 parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )
 parser.add_argument(
 "--max-sym-per-frame",
 type=int,
 default=1,
-help=(
-"Maximum number of symbols per frame. "
-"Use only when --method is greedy_search"
-),
+help="Maximum number of symbols per frame. "
+"Use only when --method is greedy_search",
 )

 return parser
@@ -199,9 +194,10 @@ def read_sound_files(
 ans = []
 for f in filenames:
 wave, sample_rate = torchaudio.load(f)
-assert (
-sample_rate == expected_sample_rate
-), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+assert sample_rate == expected_sample_rate, (
+f"expected sample rate: {expected_sample_rate}. "
+f"Given: {sample_rate}"
+)
 # We use only the first channel
 ans.append(wave[0])
 return ans
@@ -258,9 +254,13 @@ def main():
 feature_lens = [f.size(0) for f in features]
 feature_lens = torch.tensor(feature_lens, device=device)

-features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+features = pad_sequence(
+features, batch_first=True, padding_value=math.log(1e-10)
+)

-encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=features, x_lens=feature_lens
+)

 num_waves = encoder_out.size(0)
 hyp_list = []
@@ -308,7 +308,9 @@ def main():
 beam=params.beam_size,
 )
 else:
-raise ValueError(f"Unsupported decoding method: {params.method}")
+raise ValueError(
+f"Unsupported decoding method: {params.method}"
+)
 hyp_list.append(hyp)

 hyps = []
@@ -325,7 +327,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()
@@ -142,7 +142,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
@@ -413,7 +414,9 @@ def compute_loss(
 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
-info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+info["frames"] = (
+(feature_lens // params.subsampling_factor).sum().item()
+)

 # Note: We use reduction=sum while computing the loss.
 info["loss"] = loss.detach().cpu().item()
@@ -526,7 +529,9 @@ def train_one_epoch(
 loss_info.write_summary(
 tb_writer, "train/current_", params.batch_idx_train
 )
-tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+tot_loss.write_summary(
+tb_writer, "train/tot_", params.batch_idx_train
+)

 if batch_idx > 0 and batch_idx % params.valid_interval == 0:
 logging.info("Computing validation loss")
@@ -652,7 +657,9 @@ def run(rank, world_size, args):

 cur_lr = optimizer._rate
 if tb_writer is not None:
-tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+tb_writer.add_scalar(
+"train/learning_rate", cur_lr, params.batch_idx_train
+)
 tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

 if rank == 0:
egs/aishell2/ASR/local/__init__.py: 0 changes (mode changed: normal file → executable file)
@@ -83,7 +83,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
 )
 if "train" in partition:
 cut_set = (
-cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+cut_set
++ cut_set.perturb_speed(0.9)
++ cut_set.perturb_speed(1.1)
 )
 cut_set = cut_set.compute_and_store_features(
 extractor=extractor,
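The re-wrapped expression triples the training partition with 0.9x and 1.1x speed-perturbed copies, lhotse's standard 3-fold augmentation; as a function:

from lhotse import CutSet

def perturb_train(cut_set: CutSet) -> CutSet:
    return cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)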
@@ -109,7 +111,9 @@ def get_args():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)

egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py: 0 changes (mode changed: normal file → executable file)
egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py: 100 changes (mode changed: normal file → executable file)
@@ -76,12 +76,10 @@ class AiShell2AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -93,74 +91,59 @@ class AiShell2AsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -172,18 +155,17 @@ class AiShell2AsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )

         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )

         group.add_argument(
@@ -197,22 +179,18 @@ class AiShell2AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )

         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )

         group.add_argument(
@@ -238,16 +216,20 @@ class AiShell2AsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")

         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
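
As context for the MUSAN hunk above: CutMix mixes noise cuts into training cuts on the fly. A minimal sketch with lhotse (the manifest path is a placeholder; the argument values mirror the diff):

    from lhotse import load_manifest
    from lhotse.dataset import CutMix

    # Placeholder path; the recipe loads it from its manifest directory.
    cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

    # With prob=0.5, roughly half of the training cuts get MUSAN noise mixed
    # in at a random SNR drawn from 10-20 dB; preserve_id keeps the original
    # cut ids so supervisions still line up.
    transform = CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
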
@@ -262,7 +244,9 @@ class AiShell2AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -306,7 +290,9 @@ class AiShell2AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -362,7 +348,9 @@ class AiShell2AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -418,7 +406,9 @@ class AiShell2AsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
+        )

     @lru_cache()
     def test_cuts(self) -> CutSet:

@@ -168,24 +168,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )

     parser.add_argument(
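
The help text above compresses a subtle point: with --epoch E and --avg N, the average covers epochs E-N (excluded) through E, yet only the two endpoint checkpoints are read, because each checkpoint stores a running parameter average. A small sketch of that arithmetic (example values only; the real logic lives in icefall's checkpoint utilities):

    # Example values, purely illustrative.
    epoch, avg = 30, 15
    start = epoch - avg  # epoch 15, excluded from the average

    # Only the two endpoints are needed to recover the mean over (start, epoch].
    loaded = [f"epoch-{start}.pt", f"epoch-{epoch}.pt"]
    print(f"averaging over epochs ({start}, {epoch}] using {loaded}")
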
@@ -273,7 +269,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -351,7 +348,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []

     if params.decoding_method == "fast_beam_search":
@@ -410,7 +409,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -536,7 +538,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results


@@ -569,7 +573,8 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -620,7 +625,9 @@ def main():
     if "LG" in params.decoding_method:
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -654,12 +661,13 @@ def main():

     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -682,12 +690,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -715,7 +724,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -740,7 +749,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None

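
Both checkpoint-selection branches above re-wrap the same idiom. As a hedged sketch of what the slice does (the sort order is an assumption about icefall's find_checkpoints; the filenames are made up): the helper returns checkpoint-*.pt files whose iteration is at least params.iter, newest first, and the slice keeps the most recent ones.

    # Made-up checkpoint list, newest first, as find_checkpoints is assumed
    # to return it.
    checkpoints = [
        f"checkpoint-{it}.pt" for it in (496000, 492000, 488000, 484000)
    ]

    avg = 3
    filenames = checkpoints[:avg]  # the three most recent checkpoints
    assert filenames == [
        "checkpoint-496000.pt",
        "checkpoint-492000.pt",
        "checkpoint-488000.pt",
    ]
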

@@ -89,24 +89,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )

     parser.add_argument(
@@ -137,7 +133,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     add_model_arguments(parser)
@@ -170,12 +167,13 @@ def main():

     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -198,12 +196,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -231,7 +230,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -267,7 +266,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -81,11 +81,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )

     parser.add_argument(
@@ -111,12 +109,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )

     parser.add_argument(
@@ -163,7 +159,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -194,9 +191,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -256,11 +254,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]

-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )

     feature_lengths = torch.tensor(feature_lengths, device=device)

-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )

     num_waves = encoder_out.size(0)
     hyps = []
@@ -332,7 +334,9 @@ def main():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

@@ -92,7 +92,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -218,7 +220,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need "
+        "to be changed.",
     )

     parser.add_argument(
@@ -241,45 +244,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )

     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )

     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )

     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )

     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )

     parser.add_argument(
@@ -603,7 +603,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -632,16 +636,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
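
The hunk above only re-wraps the warmup logic, but the schedule is worth spelling out: the pruned RNN-T loss is phased in so it cannot overwhelm the simple loss before the model has learned a rough alignment. A standalone sketch mirroring the ternary expression in the diff (the helper function is illustrative, not icefall API):

    def pruned_loss_scale(warmup: float) -> float:
        # 0.0 before warmup completes, 0.1 while 1.0 < warmup < 2.0,
        # otherwise 1.0, exactly as in the expression above.
        return 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)

    assert pruned_loss_scale(0.5) == 0.0  # pruned loss disabled early on
    assert pruned_loss_scale(1.5) == 0.1  # small weight while settling
    assert pruned_loss_scale(2.5) == 1.0  # full weight once warmed up
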
@@ -760,7 +771,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise

         if params.print_diagnostics and batch_idx == 5:
@@ -816,7 +829,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1089,7 +1104,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-                display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+                display_and_save_batch(
+                    batch, params=params, graph_compiler=graph_compiler
+                )
                 raise


@@ -85,7 +85,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
         )
         if "train" in partition:
             cut_set = (
-                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                cut_set
+                + cut_set.perturb_speed(0.9)
+                + cut_set.perturb_speed(1.1)
             )
         cut_set = cut_set.compute_and_store_features(
             extractor=extractor,
@@ -118,7 +120,9 @@ def get_args():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )

     logging.basicConfig(format=formatter, level=logging.INFO)


@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
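
The comprehension re-wrapped above maps every token piece to its integer id, with out-of-vocabulary pieces falling back to the id of "<unk>". A self-contained sketch with made-up tables:

    # Made-up symbol table, for illustration.
    token2id = {"<unk>": 1, "ni": 2, "hao": 3}

    pieces = ["ni", "hao", "xyz"]  # "xyz" is not in the table
    ids = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
    assert ids == [2, 3, 1]  # the unknown piece maps to <unk>
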
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False


-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.

     Args:

@@ -317,7 +317,9 @@ def lexicon_to_fst(

 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()


@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")

-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e., aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="<space>", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g., <NOISE> etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
         if token_type == "lazy_pinyin":
             text = lazy_pinyin(chars_list)
             sub_ids = [
-                token_table[txt] if txt in token_table else oov_id for txt in text
+                token_table[txt] if txt in token_table else oov_id
+                for txt in text
             ]
             ids.append(sub_ids)
         else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )

     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer

@@ -74,12 +74,10 @@ class Aishell4AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )

         group.add_argument(
@@ -93,81 +91,66 @@ class Aishell4AsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )

         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )

         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )

         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )

         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
        )

         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )

         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )

         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )

         group.add_argument(
@@ -181,18 +164,17 @@ class Aishell4AsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )

         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )

         group.add_argument(
@@ -206,22 +188,18 @@ class Aishell4AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )

         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )

         group.add_argument(
@@ -244,20 +222,24 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")

         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -272,7 +254,9 @@ class Aishell4AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -316,7 +300,9 @@ class Aishell4AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -373,7 +359,9 @@ class Aishell4AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@ -117,24 +117,20 @@ def get_parser():
|
|||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=15,
|
default=15,
|
||||||
help=(
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"Number of checkpoints to average. Automatically select "
|
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch' and '--iter'"
|
"'--epoch' and '--iter'",
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-averaged-model",
|
"--use-averaged-model",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help=(
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
"Whether to load averaged model. Currently it only supports "
|
|
||||||
"using --epoch. If True, it would decode with the averaged model "
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
"Actually only the models with epoch number of `epoch-avg` and "
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
"`epoch` are loaded for averaging. "
|
"`epoch` are loaded for averaging. ",
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -205,7 +201,8 @@ def get_parser():
|
|||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
@ -263,7 +260,9 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
 feature_lens = supervisions["num_frames"].to(device)
-encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=feature, x_lens=feature_lens
+)
 hyps = []

 if params.decoding_method == "fast_beam_search":

@@ -278,7 +277,10 @@ def decode_one_batch(
 )
 for i in range(encoder_out.size(0)):
 hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+elif (
+params.decoding_method == "greedy_search"
+and params.max_sym_per_frame == 1
+):
 hyp_tokens = greedy_search_batch(
 model=model,
 encoder_out=encoder_out,

@@ -324,7 +326,11 @@ def decode_one_batch(
 return {"greedy_search": hyps}
 elif params.decoding_method == "fast_beam_search":
 return {
-f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+(
+f"beam_{params.beam}_"
+f"max_contexts_{params.max_contexts}_"
+f"max_states_{params.max_states}"
+): hyps
 }
 else:
 return {f"beam_size_{params.beam_size}": hyps}

@@ -395,7 +401,9 @@ def decode_dataset(
 if batch_idx % log_interval == 0:
 batch_str = f"{batch_idx}/{num_batches}"

-logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+logging.info(
+f"batch {batch_str}, cuts processed until now is {num_cuts}"
+)
 return results


@@ -428,7 +436,8 @@ def save_results(

 test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
 errs_info = (
-params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+params.res_dir
+/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
 )
 with open(errs_info, "w") as f:
 print("settings\tWER", file=f)

@@ -471,7 +480,9 @@ def main():
 params.suffix += f"-max-contexts-{params.max_contexts}"
 params.suffix += f"-max-states-{params.max_states}"
 elif "beam_search" in params.decoding_method:
-params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+params.suffix += (
+f"-{params.decoding_method}-beam-size-{params.beam_size}"
+)
 else:
 params.suffix += f"-context-{params.context_size}"
 params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -499,12 +510,13 @@ def main():

 if not params.use_averaged_model:
 if params.iter > 0:
-filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-: params.avg
-]
+filenames = find_checkpoints(
+params.exp_dir, iteration=-params.iter
+)[: params.avg]
 if len(filenames) == 0:
 raise ValueError(
-f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+f"No checkpoints found for"
+f" --iter {params.iter}, --avg {params.avg}"
 )
 elif len(filenames) < params.avg:
 raise ValueError(

@@ -531,12 +543,13 @@ def main():
 )
 else:
 if params.iter > 0:
-filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-: params.avg + 1
-]
+filenames = find_checkpoints(
+params.exp_dir, iteration=-params.iter
+)[: params.avg + 1]
 if len(filenames) == 0:
 raise ValueError(
-f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+f"No checkpoints found for"
+f" --iter {params.iter}, --avg {params.avg}"
 )
 elif len(filenames) < params.avg + 1:
 raise ValueError(

@@ -565,7 +578,7 @@ def main():
 filename_start = f"{params.exp_dir}/epoch-{start}.pt"
 filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
 logging.info(
-"Calculating the averaged model over epoch range from "
+f"Calculating the averaged model over epoch range from "
 f"{start} (excluded) to {params.epoch}"
 )
 model.to(device)
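Every hunk in this revert follows the same mechanical pattern: a statement that Black accepted at 88 columns is wrapped back into its parenthesized multi-line form so that it fits in 80 columns. A minimal sketch of reproducing both styles with Black's Python API (assuming a black release that exposes black.Mode and black.format_str; the model/feature names are placeholders taken from the hunk above):

import black

# 82 characters once indented inside a function body, so the call
# fits under 88 columns but not under 80.
src = (
    "def f():\n"
    "    encoder_out, encoder_out_lens = model.encoder("
    "x=feature, x_lens=feature_lens)\n"
)

# The style this commit reverts away from: the call stays on one line.
print(black.format_str(src, mode=black.Mode(line_length=88)))

# The style this commit restores: the call is wrapped.
print(black.format_str(src, mode=black.Mode(line_length=80)))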
@@ -89,24 +89,20 @@ def get_parser():
 "--avg",
 type=int,
 default=15,
-help=(
-"Number of checkpoints to average. Automatically select "
-"consecutive checkpoints before the checkpoint specified by "
-"'--epoch' and '--iter'"
-),
+help="Number of checkpoints to average. Automatically select "
+"consecutive checkpoints before the checkpoint specified by "
+"'--epoch' and '--iter'",
 )

 parser.add_argument(
 "--use-averaged-model",
 type=str2bool,
 default=False,
-help=(
-"Whether to load averaged model. Currently it only supports "
-"using --epoch. If True, it would decode with the averaged model "
-"over the epoch range from `epoch-avg` (excluded) to `epoch`."
-"Actually only the models with epoch number of `epoch-avg` and "
-"`epoch` are loaded for averaging. "
-),
+help="Whether to load averaged model. Currently it only supports "
+"using --epoch. If True, it would decode with the averaged model "
+"over the epoch range from `epoch-avg` (excluded) to `epoch`."
+"Actually only the models with epoch number of `epoch-avg` and "
+"`epoch` are loaded for averaging. ",
 )

 parser.add_argument(

@@ -140,7 +136,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 add_model_arguments(parser)

@@ -172,12 +169,13 @@ def main():

 if not params.use_averaged_model:
 if params.iter > 0:
-filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-: params.avg
-]
+filenames = find_checkpoints(
+params.exp_dir, iteration=-params.iter
+)[: params.avg]
 if len(filenames) == 0:
 raise ValueError(
-f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+f"No checkpoints found for"
+f" --iter {params.iter}, --avg {params.avg}"
 )
 elif len(filenames) < params.avg:
 raise ValueError(

@@ -204,12 +202,13 @@ def main():
 )
 else:
 if params.iter > 0:
-filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-: params.avg + 1
-]
+filenames = find_checkpoints(
+params.exp_dir, iteration=-params.iter
+)[: params.avg + 1]
 if len(filenames) == 0:
 raise ValueError(
-f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+f"No checkpoints found for"
+f" --iter {params.iter}, --avg {params.avg}"
 )
 elif len(filenames) < params.avg + 1:
 raise ValueError(

@@ -238,7 +237,7 @@ def main():
 filename_start = f"{params.exp_dir}/epoch-{start}.pt"
 filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
 logging.info(
-"Calculating the averaged model over epoch range from "
+f"Calculating the averaged model over epoch range from "
 f"{start} (excluded) to {params.epoch}"
 )
 model.to(device)

@@ -277,7 +276,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()

@@ -94,11 +94,9 @@ def get_parser():
 "--checkpoint",
 type=str,
 required=True,
-help=(
-"Path to the checkpoint. "
-"The checkpoint is assumed to be saved by "
-"icefall.checkpoint.save_checkpoint()."
-),
+help="Path to the checkpoint. "
+"The checkpoint is assumed to be saved by "
+"icefall.checkpoint.save_checkpoint().",
 )

 parser.add_argument(

@@ -124,12 +122,10 @@ def get_parser():
 "sound_files",
 type=str,
 nargs="+",
-help=(
-"The input sound file(s) to transcribe. "
-"Supported formats are those supported by torchaudio.load(). "
-"For example, wav and flac are supported. "
-"The sample rate has to be 16kHz."
-),
+help="The input sound file(s) to transcribe. "
+"Supported formats are those supported by torchaudio.load(). "
+"For example, wav and flac are supported. "
+"The sample rate has to be 16kHz.",
 )

 parser.add_argument(

@@ -176,7 +172,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )
 parser.add_argument(
 "--max-sym-per-frame",

@@ -207,9 +204,10 @@ def read_sound_files(
 ans = []
 for f in filenames:
 wave, sample_rate = torchaudio.load(f)
-assert (
-sample_rate == expected_sample_rate
-), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+assert sample_rate == expected_sample_rate, (
+f"expected sample rate: {expected_sample_rate}. "
+f"Given: {sample_rate}"
+)
 # We use only the first channel
 ans.append(wave[0])
 return ans
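The read_sound_files() hunk above is the one spot where the revert changes which part of an assert is parenthesized: the 88-column style wrapped the condition, while the restored 80-column style keeps the condition bare and wraps only the message. Both forms are equivalent; the shape to avoid is parenthesizing condition and message together, since asserting a two-element tuple is always true and the check could never fire. A small self-contained illustration (the sample-rate values are made up):

# Restored style: bare condition, parenthesized multi-line message.
sample_rate, expected_sample_rate = 8000, 16000
try:
    assert sample_rate == expected_sample_rate, (
        f"expected sample rate: {expected_sample_rate}. "
        f"Given: {sample_rate}"
    )
except AssertionError as e:
    print(e)  # expected sample rate: 16000. Given: 8000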
@@ -268,11 +266,15 @@ def main():
 features = fbank(waves)
 feature_lengths = [f.size(0) for f in features]

-features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+features = pad_sequence(
+features, batch_first=True, padding_value=math.log(1e-10)
+)

 feature_lengths = torch.tensor(feature_lengths, device=device)

-encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+encoder_out, encoder_out_lens = model.encoder(
+x=features, x_lens=feature_lengths
+)

 num_waves = encoder_out.size(0)
 hyps = []

@@ -304,7 +306,10 @@ def main():

 for i in range(encoder_out.size(0)):
 hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+elif (
+params.decoding_method == "greedy_search"
+and params.max_sym_per_frame == 1
+):
 hyp_tokens = greedy_search_batch(
 model=model,
 encoder_out=encoder_out,

@@ -345,7 +350,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()

@@ -85,7 +85,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


 def add_model_arguments(parser: argparse.ArgumentParser):

@@ -211,7 +213,8 @@ def get_parser():
 "--initial-lr",
 type=float,
 default=0.003,
-help="The initial learning rate. This value should not need to be changed.",
+help="The initial learning rate. This value should not need "
+"to be changed.",
 )

 parser.add_argument(

@@ -234,45 +237,42 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
 "--prune-range",
 type=int,
 default=5,
-help=(
-"The prune range for rnnt loss, it means how many symbols(context)"
-"we are using to compute the loss"
-),
+help="The prune range for rnnt loss, it means how many symbols(context)"
+"we are using to compute the loss",
 )

 parser.add_argument(
 "--lm-scale",
 type=float,
 default=0.25,
-help=(
-"The scale to smooth the loss with lm (output of prediction network) part."
-),
+help="The scale to smooth the loss with lm "
+"(output of prediction network) part.",
 )

 parser.add_argument(
 "--am-scale",
 type=float,
 default=0.0,
-help="The scale to smooth the loss with am (output of encoder network)part.",
+help="The scale to smooth the loss with am (output of encoder network)"
+"part.",
 )

 parser.add_argument(
 "--simple-loss-scale",
 type=float,
 default=0.5,
-help=(
-"To get pruning ranges, we will calculate a simple version"
-"loss(joiner is just addition), this simple loss also uses for"
-"training (as a regularization item). We will scale the simple loss"
-"with this parameter before adding to the final loss."
-),
+help="To get pruning ranges, we will calculate a simple version"
+"loss(joiner is just addition), this simple loss also uses for"
+"training (as a regularization item). We will scale the simple loss"
+"with this parameter before adding to the final loss.",
 )

 parser.add_argument(

@@ -599,7 +599,11 @@ def compute_loss(
 warmup: a floating point value which increases throughout training;
 values >= 1.0 are fully warmed up and have all modules present.
 """
-device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+device = (
+model.device
+if isinstance(model, DDP)
+else next(model.parameters()).device
+)
 feature = batch["inputs"]
 # at entry, feature is (N, T, C)
 assert feature.ndim == 3

@@ -629,15 +633,22 @@ def compute_loss(
 # overwhelming the simple_loss and causing it to diverge,
 # in case it had not fully learned the alignment yet.
 pruned_loss_scale = (
-0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+0.0
+if warmup < 1.0
+else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
 )
-loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+loss = (
+params.simple_loss_scale * simple_loss
++ pruned_loss_scale * pruned_loss
+)
 assert loss.requires_grad == is_training

 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
-info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+info["frames"] = (
+(feature_lens // params.subsampling_factor).sum().item()
+)

 # Note: We use reduction=sum while computing the loss.
 info["loss"] = loss.detach().cpu().item()
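For reference, the conditional expression rewrapped in the compute_loss() hunk above encodes a three-stage warmup schedule for the pruned RNN-T loss. Restated as a plain function (a sketch that is behavior-for-behavior equivalent to the expression in the diff):

def pruned_loss_scale(warmup: float) -> float:
    # Ignore the pruned loss before the model is warmed up, then
    # phase it in at 10% strength, then apply it at full strength.
    if warmup < 1.0:
        return 0.0
    if 1.0 < warmup < 2.0:
        return 0.1
    return 1.0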
@@ -816,7 +827,9 @@ def train_one_epoch(
 loss_info.write_summary(
 tb_writer, "train/current_", params.batch_idx_train
 )
-tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+tot_loss.write_summary(
+tb_writer, "train/tot_", params.batch_idx_train
+)

 if batch_idx > 0 and batch_idx % params.valid_interval == 0:
 logging.info("Computing validation loss")

@@ -84,7 +84,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
 )
 if "train" in partition:
 cut_set = (
-cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+cut_set
++ cut_set.perturb_speed(0.9)
++ cut_set.perturb_speed(1.1)
 )
 cur_num_jobs = num_jobs if ex is None else 80
 cur_num_jobs = min(cur_num_jobs, len(cut_set))

@@ -119,7 +121,9 @@ def get_args():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)

@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
 cur_state = loop_state

 word = word2id[word]
-pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+pieces = [
+token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+]

 for i in range(len(pieces) - 1):
 w = word if i == 0 else eps

@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
 return False


-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
 """Generate a lexicon from a word list and token_sym_table.

 Args:

@@ -317,7 +317,9 @@ def lexicon_to_fst(

 def get_args():
 parser = argparse.ArgumentParser()
-parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+parser.add_argument(
+"--lang-dir", type=str, help="The lang dir, data/lang_phone"
+)
 return parser.parse_args()


@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
 fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
 fsa.draw("L.pdf", title="L")

-fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+fsa_disambig = lexicon_to_fst(
+lexicon_disambig, phone2id=phone2id, word2id=word2id
+)
 fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
 fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
 fsa_disambig.draw("L_disambig.pdf", title="L_disambig")

@@ -30,8 +30,8 @@ with word segmenting:

 import argparse

-import jieba
 import paddle
+import jieba
 from tqdm import tqdm

 paddle.enable_static()

@@ -50,15 +50,15 @@ def get_parser():
 "-n",
 default=1,
 type=int,
-help=(
-"number of characters to split, i.e., aabb -> a a b"
-" b with -n 1 and aa bb with -n 2"
-),
+help="number of characters to split, i.e., \
+aabb -> a a b b with -n 1 and aa bb with -n 2",
 )
 parser.add_argument(
 "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
 )
-parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+parser.add_argument(
+"--space", default="<space>", type=str, help="space symbol"
+)
 parser.add_argument(
 "--non-lang-syms",
 "-l",

@@ -66,7 +66,9 @@ def get_parser():
 type=str,
 help="list of non-linguistic symobles, e.g., <NOISE> etc.",
 )
-parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+parser.add_argument(
+"text", type=str, default=False, nargs="?", help="input text"
+)
 parser.add_argument(
 "--trans_type",
 "-t",

@@ -106,7 +108,8 @@ def token2id(
 if token_type == "lazy_pinyin":
 text = lazy_pinyin(chars_list)
 sub_ids = [
-token_table[txt] if txt in token_table else oov_id for txt in text
+token_table[txt] if txt in token_table else oov_id
+for txt in text
 ]
 ids.append(sub_ids)
 else: # token_type = "pinyin"

@@ -132,7 +135,9 @@ def main():
 if args.text:
 f = codecs.open(args.text, encoding="utf-8")
 else:
-f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+f = codecs.getreader("utf-8")(
+sys.stdin if is_python2 else sys.stdin.buffer
+)

 sys.stdout = codecs.getwriter("utf-8")(
 sys.stdout if is_python2 else sys.stdout.buffer

@@ -81,12 +81,10 @@ class AlimeetingAsrDataModule:
 def add_arguments(cls, parser: argparse.ArgumentParser):
 group = parser.add_argument_group(
 title="ASR data related options",
-description=(
-"These options are used for the preparation of "
-"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-"effective batch sizes, sampling strategies, applied data "
-"augmentations, etc."
-),
+description="These options are used for the preparation of "
+"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+"effective batch sizes, sampling strategies, applied data "
+"augmentations, etc.",
 )
 group.add_argument(
 "--manifest-dir",

@@ -98,91 +96,75 @@ class AlimeetingAsrDataModule:
 "--max-duration",
 type=int,
 default=200.0,
-help=(
-"Maximum pooled recordings duration (seconds) in a "
-"single batch. You can reduce it if it causes CUDA OOM."
-),
+help="Maximum pooled recordings duration (seconds) in a "
+"single batch. You can reduce it if it causes CUDA OOM.",
 )
 group.add_argument(
 "--bucketing-sampler",
 type=str2bool,
 default=True,
-help=(
-"When enabled, the batches will come from buckets of "
-"similar duration (saves padding frames)."
-),
+help="When enabled, the batches will come from buckets of "
+"similar duration (saves padding frames).",
 )
 group.add_argument(
 "--num-buckets",
 type=int,
 default=300,
-help=(
-"The number of buckets for the DynamicBucketingSampler"
-"(you might want to increase it for larger datasets)."
-),
+help="The number of buckets for the DynamicBucketingSampler"
+"(you might want to increase it for larger datasets).",
 )
 group.add_argument(
 "--concatenate-cuts",
 type=str2bool,
 default=False,
-help=(
-"When enabled, utterances (cuts) will be concatenated "
-"to minimize the amount of padding."
-),
+help="When enabled, utterances (cuts) will be concatenated "
+"to minimize the amount of padding.",
 )
 group.add_argument(
 "--duration-factor",
 type=float,
 default=1.0,
-help=(
-"Determines the maximum duration of a concatenated cut "
-"relative to the duration of the longest cut in a batch."
-),
+help="Determines the maximum duration of a concatenated cut "
+"relative to the duration of the longest cut in a batch.",
 )
 group.add_argument(
 "--gap",
 type=float,
 default=1.0,
-help=(
-"The amount of padding (in seconds) inserted between "
-"concatenated cuts. This padding is filled with noise when "
-"noise augmentation is used."
-),
+help="The amount of padding (in seconds) inserted between "
+"concatenated cuts. This padding is filled with noise when "
+"noise augmentation is used.",
 )
 group.add_argument(
 "--on-the-fly-feats",
 type=str2bool,
 default=False,
-help=(
-"When enabled, use on-the-fly cut mixing and feature "
-"extraction. Will drop existing precomputed feature manifests "
-"if available."
-),
+help="When enabled, use on-the-fly cut mixing and feature "
+"extraction. Will drop existing precomputed feature manifests "
+"if available.",
 )
 group.add_argument(
 "--shuffle",
 type=str2bool,
 default=True,
-help=(
-"When enabled (=default), the examples will be shuffled for each epoch."
-),
+help="When enabled (=default), the examples will be "
+"shuffled for each epoch.",
 )
 group.add_argument(
 "--return-cuts",
 type=str2bool,
 default=True,
-help=(
-"When enabled, each batch will have the "
-"field: batch['supervisions']['cut'] with the cuts that "
-"were used to construct it."
-),
+help="When enabled, each batch will have the "
+"field: batch['supervisions']['cut'] with the cuts that "
+"were used to construct it.",
 )

 group.add_argument(
 "--num-workers",
 type=int,
 default=2,
-help="The number of training dataloader workers that collect the batches.",
+help="The number of training dataloader workers that "
+"collect the batches.",
 )

 group.add_argument(

@@ -196,22 +178,18 @@ class AlimeetingAsrDataModule:
 "--spec-aug-time-warp-factor",
 type=int,
 default=80,
-help=(
-"Used only when --enable-spec-aug is True. "
-"It specifies the factor for time warping in SpecAugment. "
-"Larger values mean more warping. "
-"A value less than 1 means to disable time warp."
-),
+help="Used only when --enable-spec-aug is True. "
+"It specifies the factor for time warping in SpecAugment. "
+"Larger values mean more warping. "
+"A value less than 1 means to disable time warp.",
 )

 group.add_argument(
 "--enable-musan",
 type=str2bool,
 default=True,
-help=(
-"When enabled, select noise from MUSAN and mix it"
-"with training dataset. "
-),
+help="When enabled, select noise from MUSAN and mix it"
+"with training dataset. ",
 )

 def train_dataloaders(

@@ -227,20 +205,24 @@ class AlimeetingAsrDataModule:
 The state dict for the training sampler.
 """
 logging.info("About to get Musan cuts")
-cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+cuts_musan = load_manifest(
+self.args.manifest_dir / "musan_cuts.jsonl.gz"
+)

 transforms = []
 if self.args.enable_musan:
 logging.info("Enable MUSAN")
 transforms.append(
-CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+CutMix(
+cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+)
 )
 else:
 logging.info("Disable MUSAN")

 if self.args.concatenate_cuts:
 logging.info(
-"Using cut concatenation with duration factor "
+f"Using cut concatenation with duration factor "
 f"{self.args.duration_factor} and gap {self.args.gap}."
 )
 # Cut concatenation should be the first transform in the list,

@@ -255,7 +237,9 @@ class AlimeetingAsrDataModule:
 input_transforms = []
 if self.args.enable_spec_aug:
 logging.info("Enable SpecAugment")
-logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+logging.info(
+f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+)
 # Set the value of num_frame_masks according to Lhotse's version.
 # In different Lhotse's versions, the default of num_frame_masks is
 # different.

@@ -298,7 +282,9 @@ class AlimeetingAsrDataModule:
 # Drop feats to be on the safe side.
 train = K2SpeechRecognitionDataset(
 cut_transforms=transforms,
-input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+input_strategy=OnTheFlyFeatures(
+Fbank(FbankConfig(num_mel_bins=80))
+),
 input_transforms=input_transforms,
 return_cuts=self.args.return_cuts,
 )

@@ -355,7 +341,9 @@ class AlimeetingAsrDataModule:
 if self.args.on_the_fly_feats:
 validate = K2SpeechRecognitionDataset(
 cut_transforms=transforms,
-input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+input_strategy=OnTheFlyFeatures(
+Fbank(FbankConfig(num_mel_bins=80))
+),
 return_cuts=self.args.return_cuts,
 )
 else:

@@ -70,7 +70,11 @@ from beam_search import (
 from lhotse.cut import Cut
 from train import get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+average_checkpoints,
+find_checkpoints,
+load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
 AttributeDict,

@@ -89,30 +93,25 @@ def get_parser():
 "--epoch",
 type=int,
 default=28,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )

 parser.add_argument(
 "--batch",
 type=int,
 default=None,
-help=(
-"It specifies the batch checkpoint to use for decoding."
-"Note: Epoch counts from 0."
-),
+help="It specifies the batch checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )

 parser.add_argument(
 "--avg",
 type=int,
 default=15,
-help=(
-"Number of checkpoints to average. Automatically select "
-"consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+help="Number of checkpoints to average. Automatically select "
+"consecutive checkpoints before the checkpoint specified by "
+"'--epoch'. ",
 )

 parser.add_argument(

@@ -194,7 +193,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )
 parser.add_argument(
 "--max-sym-per-frame",

@@ -249,7 +249,9 @@ def decode_one_batch(

 supervisions = batch["supervisions"]
 feature_lens = supervisions["num_frames"].to(device)
-encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+encoder_out, encoder_out_lens = model.encoder(
+x=feature, x_lens=feature_lens
+)
 hyps = []

 if params.decoding_method == "fast_beam_search":

@@ -264,7 +266,10 @@ def decode_one_batch(
 )
 for i in range(encoder_out.size(0)):
 hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+elif (
+params.decoding_method == "greedy_search"
+and params.max_sym_per_frame == 1
+):
 hyp_tokens = greedy_search_batch(
 model=model,
 encoder_out=encoder_out,

@@ -310,7 +315,11 @@ def decode_one_batch(
 return {"greedy_search": hyps}
 elif params.decoding_method == "fast_beam_search":
 return {
-f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+(
+f"beam_{params.beam}_"
+f"max_contexts_{params.max_contexts}_"
+f"max_states_{params.max_states}"
+): hyps
 }
 else:
 return {f"beam_size_{params.beam_size}": hyps}

@@ -381,7 +390,9 @@ def decode_dataset(
 if batch_idx % log_interval == 0:
 batch_str = f"{batch_idx}/{num_batches}"

-logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+logging.info(
+f"batch {batch_str}, cuts processed until now is {num_cuts}"
+)
 return results


@@ -414,7 +425,8 @@ def save_results(

 test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
 errs_info = (
-params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+params.res_dir
+/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
 )
 with open(errs_info, "w") as f:
 print("settings\tWER", file=f)

@@ -551,7 +563,8 @@ def main():
 )

 dev_shards = [
-str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+str(path)
+for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
 ]
 cuts_dev_webdataset = CutSet.from_webdataset(
 dev_shards,

@@ -561,7 +574,8 @@ def main():
 )

 test_shards = [
-str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+str(path)
+for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
 ]
 cuts_test_webdataset = CutSet.from_webdataset(
 test_shards,

@@ -574,7 +588,9 @@ def main():
 return 1.0 <= c.duration

 cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
-cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
+cuts_test_webdataset = cuts_test_webdataset.filter(
+remove_short_and_long_utt
+)

 dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
 test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)

@@ -62,20 +62,17 @@ def get_parser():
 "--epoch",
 type=int,
 default=28,
-help=(
-"It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-),
+help="It specifies the checkpoint to use for decoding."
+"Note: Epoch counts from 0.",
 )

 parser.add_argument(
 "--avg",
 type=int,
 default=15,
-help=(
-"Number of checkpoints to average. Automatically select "
-"consecutive checkpoints before the checkpoint specified by "
-"'--epoch'. "
-),
+help="Number of checkpoints to average. Automatically select "
+"consecutive checkpoints before the checkpoint specified by "
+"'--epoch'. ",
 )

 parser.add_argument(

@@ -106,7 +103,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 return parser

@@ -175,7 +173,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()

@@ -85,11 +85,9 @@ def get_parser():
 "--checkpoint",
 type=str,
 required=True,
-help=(
-"Path to the checkpoint. "
-"The checkpoint is assumed to be saved by "
-"icefall.checkpoint.save_checkpoint()."
-),
+help="Path to the checkpoint. "
+"The checkpoint is assumed to be saved by "
+"icefall.checkpoint.save_checkpoint().",
 )

 parser.add_argument(

@@ -114,12 +112,10 @@ def get_parser():
 "sound_files",
 type=str,
 nargs="+",
-help=(
-"The input sound file(s) to transcribe. "
-"Supported formats are those supported by torchaudio.load(). "
-"For example, wav and flac are supported. "
-"The sample rate has to be 16kHz."
-),
+help="The input sound file(s) to transcribe. "
+"Supported formats are those supported by torchaudio.load(). "
+"For example, wav and flac are supported. "
+"The sample rate has to be 16kHz.",
 )

 parser.add_argument(

@@ -166,7 +162,8 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(

@@ -196,9 +193,10 @@ def read_sound_files(
 ans = []
 for f in filenames:
 wave, sample_rate = torchaudio.load(f)
-assert (
-sample_rate == expected_sample_rate
-), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+assert sample_rate == expected_sample_rate, (
+f"expected sample rate: {expected_sample_rate}. "
+f"Given: {sample_rate}"
+)
 # We use only the first channel
 ans.append(wave[0])
 return ans

@@ -259,7 +257,9 @@ def main():
 features = fbank(waves)
 feature_lengths = [f.size(0) for f in features]

-features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+features = pad_sequence(
+features, batch_first=True, padding_value=math.log(1e-10)
+)

 feature_lengths = torch.tensor(feature_lengths, device=device)

@@ -284,7 +284,10 @@ def main():
 )
 for i in range(encoder_out.size(0)):
 hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+elif (
+params.decoding_method == "greedy_search"
+and params.max_sym_per_frame == 1
+):
 hyp_tokens = greedy_search_batch(
 model=model,
 encoder_out=encoder_out,

@@ -336,7 +339,9 @@ def main():


 if __name__ == "__main__":
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 main()

@@ -81,7 +81,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]

 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

@@ -185,45 +187,42 @@ def get_parser():
 "--context-size",
 type=int,
 default=2,
-help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+help="The context size in the decoder. 1 means bigram; "
+"2 means tri-gram",
 )

 parser.add_argument(
 "--prune-range",
 type=int,
 default=5,
-help=(
-"The prune range for rnnt loss, it means how many symbols(context)"
-"we are using to compute the loss"
-),
+help="The prune range for rnnt loss, it means how many symbols(context)"
+"we are using to compute the loss",
 )

 parser.add_argument(
 "--lm-scale",
 type=float,
 default=0.25,
-help=(
-"The scale to smooth the loss with lm (output of prediction network) part."
-),
+help="The scale to smooth the loss with lm "
+"(output of prediction network) part.",
 )

 parser.add_argument(
 "--am-scale",
 type=float,
 default=0.0,
-help="The scale to smooth the loss with am (output of encoder network)part.",
+help="The scale to smooth the loss with am (output of encoder network)"
+"part.",
 )

 parser.add_argument(
 "--simple-loss-scale",
 type=float,
 default=0.5,
-help=(
-"To get pruning ranges, we will calculate a simple version"
-"loss(joiner is just addition), this simple loss also uses for"
-"training (as a regularization item). We will scale the simple loss"
-"with this parameter before adding to the final loss."
-),
+help="To get pruning ranges, we will calculate a simple version"
+"loss(joiner is just addition), this simple loss also uses for"
+"training (as a regularization item). We will scale the simple loss"
+"with this parameter before adding to the final loss.",
 )

 parser.add_argument(

@@ -543,15 +542,22 @@ def compute_loss(
 # overwhelming the simple_loss and causing it to diverge,
 # in case it had not fully learned the alignment yet.
 pruned_loss_scale = (
-0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+0.0
+if warmup < 1.0
+else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
 )
-loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+loss = (
+params.simple_loss_scale * simple_loss
++ pruned_loss_scale * pruned_loss
+)
 assert loss.requires_grad == is_training

 info = MetricsTracker()
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
-info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+info["frames"] = (
+(feature_lens // params.subsampling_factor).sum().item()
+)

 # Note: We use reduction=sum while computing the loss.
 info["loss"] = loss.detach().cpu().item()

@@ -705,7 +711,9 @@ def train_one_epoch(
 loss_info.write_summary(
 tb_writer, "train/current_", params.batch_idx_train
 )
-tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+tot_loss.write_summary(
+tb_writer, "train/tot_", params.batch_idx_train
+)

 if batch_idx > 0 and batch_idx % params.valid_interval == 0:
 logging.info("Computing validation loss")

@@ -25,10 +25,15 @@ from random import Random
 from typing import List, Tuple

 import torch
-from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
+from lhotse import (
 CutSet,
 Fbank,
 FbankConfig,
+# fmt: off
+# See the following for why LilcomChunkyWriter is preferred
+# https://github.com/k2-fsa/icefall/pull/404
+# https://github.com/lhotse-speech/lhotse/pull/527
+# fmt: on
 LilcomChunkyWriter,
 RecordingSet,
 SupervisionSet,
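The lhotse import hunk above also undoes a comment artifact: in the 88-column version, the # fmt: off / # fmt: on guards and the explanatory comments had been folded onto the "from lhotse import (" line, joined by semicolons (presumably by the import sorter), and the revert restores them as standalone comment lines inside the import list. The markers themselves are standard Black directives that fence off a region from reformatting. A minimal sketch:

# fmt: off
lookup = [
    1,   2,   3,
    10,  20,  30,
]
# fmt: on

Black leaves everything between the two markers untouched, so hand-aligned layout like the columns above survives a reformat.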
@@ -76,13 +81,17 @@ def make_cutset_blueprints(
 cut_sets.append((f"eval{i}", cut_set))

 # Create train and valid cuts
-logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
+logging.info(
+"Loading, trimming, and shuffling the remaining core+noncore cuts."
+)
 recording_set = RecordingSet.from_file(
 manifest_dir / "csj_recordings_core.jsonl.gz"
 ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
 supervision_set = SupervisionSet.from_file(
 manifest_dir / "csj_supervisions_core.jsonl.gz"
-) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
+) + SupervisionSet.from_file(
+manifest_dir / "csj_supervisions_noncore.jsonl.gz"
+)

 cut_set = CutSet.from_manifests(
 recordings=recording_set,

@@ -92,12 +101,15 @@ def make_cutset_blueprints(
 cut_set = cut_set.shuffle(Random(RNG_SEED))

 logging.info(
-f"Creating valid and train cuts from core and noncore,split at {split}."
+"Creating valid and train cuts from core and noncore,"
+f"split at {split}."
 )
 valid_set = CutSet.from_cuts(islice(cut_set, 0, split))

 train_set = CutSet.from_cuts(islice(cut_set, split, None))
-train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+train_set = (
+train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+)

 cut_sets.extend([("valid", valid_set), ("train", train_set)])

@@ -110,9 +122,15 @@ def get_args():
 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )

-parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
-parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
-parser.add_argument("--split", type=int, default=4000, help="Split at this index")
+parser.add_argument(
+"--manifest-dir", type=Path, help="Path to save manifests"
+)
+parser.add_argument(
+"--fbank-dir", type=Path, help="Path to save fbank features"
+)
+parser.add_argument(
+"--split", type=int, default=4000, help="Split at this index"
+)

 return parser.parse_args()

@@ -123,7 +141,9 @@ def main():
 extractor = Fbank(FbankConfig(num_mel_bins=80))
 num_jobs = min(16, os.cpu_count())

-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)

@@ -26,6 +26,7 @@ from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+
 ARGPARSE_DESCRIPTION = """
 This file computes fbank features of the musan dataset.
 It looks for manifests in the directory data/manifests.

@@ -83,7 +84,9 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
 # create chunks of Musan with duration 5 - 10 seconds
 musan_cuts = (
 CutSet.from_manifests(
-recordings=combine(part["recordings"] for part in manifests.values())
+recordings=combine(
+part["recordings"] for part in manifests.values()
+)
 )
 .cut_into_windows(10.0)
 .filter(lambda c: c.duration > 5)

@@ -104,15 +107,21 @@ def get_args():
 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )

-parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
-parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
+parser.add_argument(
+"--manifest-dir", type=Path, help="Path to save manifests"
+)
+parser.add_argument(
+"--fbank-dir", type=Path, help="Path to save fbank features"
+)

 return parser.parse_args()


 if __name__ == "__main__":
 args = get_args()
-formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+formatter = (
+"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+)

 logging.basicConfig(format=formatter, level=logging.INFO)
 compute_fbank_musan(args.manifest_dir, args.fbank_dir)

@@ -318,3 +318,4 @@ spk_id = 2
 ャ = ǐa
 ュ = ǐu
 ョ = ǐo
+

@@ -318,3 +318,4 @@ spk_id = 2
 ャ = ǐa
 ュ = ǐu
 ョ = ǐo
+
Some files were not shown because too many files have changed in this diff.