[Ready to merge] Pruned transducer stateless5 recipe for AISHELL4 (#399)

* pruned-transducer-stateless5 recipe for aishell4

* pruned-transducer-stateless5 recipe for aishell4

* do some changes and text normalize

* do some changes

* add text normalize

* combine the training data and decode without webdataset

* update codes for merging

* Do a change for READMD.md
This commit is contained in:
Mingshuang Luo 2022-06-14 22:19:05 +08:00 committed by GitHub
parent 53f38c01d2
commit 5c3ee8bfcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 6000 additions and 1 deletions

View File

@ -31,6 +31,7 @@ We provide the following recipes:
- [Aidatatang_200zh][aidatatang_200zh]
- [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting]
- [Aishell4][aishell4]
### yesno
@ -270,6 +271,21 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
### Aishell4
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
The best CER(%) results:
| | test |
|----------------------|--------|
| greedy search | 29.89 |
| fast beam search | 28.91 |
| modified beam search | 29.08 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++,
@ -298,6 +314,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
@ -307,5 +324,5 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[aidatatang_200zh]: egs/aidatatang_200zh/ASR
[wenetspeech]: egs/wenetspeech/ASR
[alimeeting]: egs/alimeeting/ASR
[aishell4]: egs/aishell4/ASR
[k2]: https://github.com/k2-fsa/k2
)

View File

@ -0,0 +1,19 @@
# Introduction
This recipe includes some different ASR models trained with Aishell4 (including S, M and L three subsets).
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

117
egs/aishell4/ASR/RESULTS.md Normal file
View File

@ -0,0 +1,117 @@
## Results
### Aishell4 Char training results (Pruned Transducer Stateless5)
#### 2022-06-13
Using the codes from this PR https://github.com/k2-fsa/icefall/pull/399.
When use-averaged-model=False, the CERs are
| | test | comment |
|------------------------------------|------------|------------------------------------------|
| greedy search | 30.05 | --epoch 30, --avg 25, --max-duration 800 |
| modified beam search (beam size 4) | 29.16 | --epoch 30, --avg 25, --max-duration 800 |
| fast beam search (set as default) | 29.20 | --epoch 30, --avg 25, --max-duration 1500|
When use-averaged-model=True, the CERs are
| | test | comment |
|------------------------------------|------------|----------------------------------------------------------------------|
| greedy search | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
| modified beam search (beam size 4) | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
| fast beam search (set as default) | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 220 \
--save-every-n 4000
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars
When use-averaged-model=False, the decoding command is:
```
epoch=30
avg=25
## greedy search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800
## modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4
## fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
When use-averaged-model=True, the decoding command is:
```
iter=36000
avg=8
## greedy search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--use-averaged-model True
## modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-averaged-model True
## fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--use-averaged-model True
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aishell4_pruned_transducer_stateless5>

View File

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the aidatatang_200zh dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train_S",
"train_M",
"train_L",
"test",
)
prefix = "aishell4"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=ChunkedLilcomHdf5Writer,
)
logging.info("About splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
min_duration=None,
)
cut_set.to_file(output_dir / cuts_filename)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,113 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""
from lhotse import load_manifest
def main():
paths = [
"./data/fbank/cuts_train_S.json.gz",
"./data/fbank/cuts_train_M.json.gz",
"./data/fbank/cuts_train_L.json.gz",
"./data/fbank/cuts_test.json.gz",
]
for path in paths:
print(f"Starting display the statistics for {path}")
cuts = load_manifest(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
Starting display the statistics for ./data/fbank/cuts_train_S.json.gz
Cuts count: 91995
Total duration (hours): 95.8
Speech duration (hours): 95.8 (100.0%)
***
Duration statistics (seconds):
mean 3.7
std 7.1
min 0.1
25% 0.9
50% 2.5
75% 5.4
99% 15.3
99.5% 17.5
99.9% 23.3
max 1021.7
Starting display the statistics for ./data/fbank/cuts_train_M.json.gz
Cuts count: 177195
Total duration (hours): 179.5
Speech duration (hours): 179.5 (100.0%)
***
Duration statistics (seconds):
mean 3.6
std 6.4
min 0.0
25% 0.9
50% 2.4
75% 5.2
99% 14.9
99.5% 17.0
99.9% 23.5
max 990.4
Starting display the statistics for ./data/fbank/cuts_train_L.json.gz
Cuts count: 37572
Total duration (hours): 49.1
Speech duration (hours): 49.1 (100.0%)
***
Duration statistics (seconds):
mean 4.7
std 4.0
min 0.2
25% 1.6
50% 3.7
75% 6.7
99% 17.5
99.5% 19.8
99.9% 26.2
max 87.4
Starting display the statistics for ./data/fbank/cuts_test.json.gz
Cuts count: 10574
Total duration (hours): 12.1
Speech duration (hours): 12.1 (100.0%)
***
Duration statistics (seconds):
mean 4.1
std 3.4
min 0.2
25% 1.4
50% 3.2
75% 5.8
99% 14.4
99.5% 14.9
99.9% 16.5
max 17.9
"""

View File

@ -0,0 +1,248 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/text,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import re
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
Token symbol table that mapping token to token ids.
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
tokens.
"""
lexicon = []
for word in words:
chars = list(word.strip(" \t"))
if contain_oov(token_sym_table, chars):
continue
lexicon.append((word, chars))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(text_file: str) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
text_file:
A file that contains text lines to generate tokens.
Returns:
Return a dict whose keys are tokens and values are token ids ranged
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
whitespace = re.compile(r"([ \t\r\n]+)")
with open(text_file, "r", encoding="utf-8") as f:
for line in f:
line = re.sub(whitespace, "", line)
chars = list(line)
for char in chars:
if char not in tokens:
tokens[char] = len(tokens)
return tokens
def main():
lang_dir = Path("data/lang_char")
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(text_file)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,390 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args()
def main():
out_dir = Path(get_args().lang_dir)
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input words.txt without ids:
- words_no_ids.txt
and generates the new words.txt with related ids.
- words.txt
"""
import argparse
import logging
from tqdm import tqdm
def get_parser():
parser = argparse.ArgumentParser(
description="Prepare words.txt",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-file",
default="data/lang_char/words_no_ids.txt",
type=str,
help="the words file without ids for WenetSpeech",
)
parser.add_argument(
"--output-file",
default="data/lang_char/words.txt",
type=str,
help="the words file with ids for WenetSpeech",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
f = open(input_file, "r", encoding="utf-8")
lines = f.readlines()
new_lines = []
add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
new_lines.extend(add_words)
logging.info("Starting reading the input file")
for i in tqdm(range(len(lines))):
x = lines[i]
idx = 4 + i
new_line = str(x.strip("\n")) + " " + str(idx)
new_lines.append(new_line)
logging.info("Starting writing the words.txt")
f_out = open(output_file, "w", encoding="utf-8")
for line in new_lines:
f_out.write(line)
f_out.write("\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import os
import tempfile
import k2
from prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,
get_words,
lexicon_to_fst,
read_lexicon,
write_lexicon,
write_mapping,
)
def generate_lexicon_file() -> str:
fd, filename = tempfile.mkstemp()
os.close(fd)
s = """
!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
""".strip()
with open(filename, "w") as f:
f.write(s)
return filename
def test_read_lexicon(filename: str):
lexicon = read_lexicon(filename)
phones = get_phones(lexicon)
words = get_words(lexicon)
print(lexicon)
print(phones)
print(words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
print(lexicon_disambig)
print("max disambig:", f"#{max_disambig}")
phones = ["<eps>", "SIL", "SPN"] + phones
for i in range(max_disambig + 1):
phones.append(f"#{i}")
words = ["<eps>"] + words
phone2id = generate_id_map(phones)
word2id = generate_id_map(words)
print(phone2id)
print(word2id)
write_mapping("phones.txt", phone2id)
write_mapping("words.txt", word2id)
write_lexicon("a.txt", lexicon)
write_lexicon("a_disambig.txt", lexicon_disambig)
fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
def main():
filename = generate_lexicon_file()
test_read_lexicon(filename)
os.remove(filename)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text", which refers to the transcript file for
WenetSpeech:
- text
and generates the output file text_word_segmentation which is implemented
with word segmenting:
- text_words_segmentation
"""
import argparse
import jieba
from tqdm import tqdm
jieba.enable_paddle()
def get_parser():
parser = argparse.ArgumentParser(
description="Chinese Word Segmentation for text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-file",
default="data/lang_char/text",
type=str,
help="the input text file for WenetSpeech",
)
parser.add_argument(
"--output-file",
default="data/lang_char/text_words_segmentation",
type=str,
help="the text implemented with words segmenting for WenetSpeech",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
f = open(input_file, "r", encoding="utf-8")
lines = f.readlines()
new_lines = []
for i in tqdm(range(len(lines))):
x = lines[i].rstrip()
seg_list = jieba.cut(x, use_paddle=True)
new_line = " ".join(seg_list)
new_lines.append(new_line)
f_new = open(output_file, "w", encoding="utf-8")
for line in new_lines:
f_new.write(line)
f_new.write("\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,195 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe)
# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import codecs
import re
import sys
from typing import List
from pypinyin import lazy_pinyin, pinyin
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "pinyin", "lazy_pinyin"],
help="""Transcript type. char/pinyin/lazy_pinyin""",
)
return parser
def token2id(
texts, token_table, token_type: str = "lazy_pinyin", oov: str = "<unk>"
) -> List[List[int]]:
"""Convert token to id.
Args:
texts:
The input texts, it refers to the chinese text here.
token_table:
The token table is built based on "data/lang_xxx/token.txt"
token_type:
The type of token, such as "pinyin" and "lazy_pinyin".
oov:
Out of vocabulary token. When a word(token) in the transcript
does not exist in the token list, it is replaced with `oov`.
Returns:
The list of ids for the input texts.
"""
if texts is None:
raise ValueError("texts can't be None!")
else:
oov_id = token_table[oov]
ids: List[List[int]] = []
for text in texts:
chars_list = list(str(text))
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
text = pinyin(chars_list)
sub_ids = [
token_table[txt[0]] if txt[0] in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
return ids
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols :]) # noqa E203
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
if args.trans_type == "pinyin":
a = pinyin(list(str(a)))
a = [one[0] for one in a]
if args.trans_type == "lazy_pinyin":
a = lazy_pinyin(list(str(a)))
a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = "".join(a_flat)
print(a_chars)
line = f.readline()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,119 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text_full", which includes three transcript files
(train_S, train_M and train_L) for AISHELL4:
- text_full
and generates the output file text_normalize which is implemented
to normalize text:
- text
"""
import argparse
from tqdm import tqdm
def get_parser():
parser = argparse.ArgumentParser(
description="Normalizing for text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input",
default="data/lang_char/text_full",
type=str,
help="the input text files for AISHELL4",
)
parser.add_argument(
"--output",
default="data/lang_char/text",
type=str,
help="the text implemented with normalizer for AISHELL4",
)
return parser
def text_normalize(str_line: str):
line = str_line.strip().rstrip("\n")
line = line.replace(" ", "")
line = line.replace("<sil>", "")
line = line.replace("<%>", "")
line = line.replace("<->", "")
line = line.replace("<$>", "")
line = line.replace("<#>", "")
line = line.replace("<_>", "")
line = line.replace("<space>", "")
line = line.replace("`", "")
line = line.replace("&", "")
line = line.replace(",", "")
line = line.replace("", "")
line = line.replace("", "A")
line = line.replace("", "B")
line = line.replace("", "C")
line = line.replace("", "K")
line = line.replace("", "T")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("·", "")
line = line.replace("*", "")
line = line.replace("", "")
line = line.replace("$", "")
line = line.replace("+", "")
line = line.replace("-", "")
line = line.replace("\\", "")
line = line.replace("?", "")
line = line.replace("", "")
line = line.replace("%", "")
line = line.replace(".", "")
line = line.replace("<", "")
line = line.replace("", "")
line = line.upper()
return line
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input
output_file = args.output
f = open(input_file, "r", encoding="utf-8")
lines = f.readlines()
new_lines = []
for i in tqdm(range(len(lines))):
new_line = text_normalize(lines[i])
new_lines.append(new_line)
f_new = open(output_file, "w", encoding="utf-8")
for line in new_lines:
f_new.write(line)
f_new.write("\n")
if __name__ == "__main__":
main()

160
egs/aishell4/ASR/prepare.sh Executable file
View File

@ -0,0 +1,160 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/aishell4
# You can find four directories:train_S, train_M, train_L and test.
# You can download it from https://openslr.org/111/
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/aishell4,
# you can create a symlink
#
# ln -sfv /path/to/aishell4 $dl_dir/aishell4
#
if [ ! -f $dl_dir/aishell4/train_L ]; then
lhotse download aishell4 $dl_dir/aishell4
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/musan
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare aishell4 manifest"
# We assume that you have downloaded the aishell4 corpus
# to $dl_dir/aishell4
if [ ! -f data/manifests/aishell4/.manifests.done ]; then
mkdir -p data/manifests/aishell4
lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
touch data/manifests/aishell4/.manifests.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank/aishell4
lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
touch data/fbank/aishell4/.fbank.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
if [ ! -f data/manifests/.musan_manifests.done ]; then
log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
mkdir -p data/fbank
./local/compute_fbank_musan.py
touch data/fbank/.msuan.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py
touch data/fbank/.aishell4.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
# Prepare text.
# Note: in Linux, you can install jq with the following command:
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text_S
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text_M
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text_L
for r in text_S text_M text_L ; do
cat $lang_char_dir/$r >> $lang_char_dir/text_full
done
# Prepare text normalize
python ./local/text_normalize.py \
--input $lang_char_dir/text_full \
--output $lang_char_dir/text
# Prepare words segments
python ./local/text2segments.py \
--input $lang_char_dir/text \
--output $lang_char_dir/text_words_segmentation
cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
| sort -u | sed "/^$/d" \
| uniq > $lang_char_dir/words_no_ids.txt
# Prepare words.txt
if [ ! -f $lang_char_dir/words.txt ]; then
./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
--output-file $lang_char_dir/words.txt
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
fi
fi

View File

@ -0,0 +1,448 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class Aishell4AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=300,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=30000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_dl.sampler.load_state_dict(sampler_state_dict)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_S_cuts(self) -> CutSet:
logging.info("About to get S train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
)
@lru_cache()
def train_M_cuts(self) -> CutSet:
logging.info("About to get M train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
)
@lru_cache()
def train_L_cuts(self) -> CutSet:
logging.info("About to get L train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
# Aishell4 doesn't have dev data, here use test to replace dev.
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,630 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When use-averaged-model=True, usage:
(1) greedy search
./pruned_transducer_stateless5/decode.py \
--iter 36000 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 800 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./pruned_transducer_stateless5/decode.py \
--iter 36000 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-averaged-model True
(3) fast beam search
./pruned_transducer_stateless5/decode.py \
--iter 36000 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 800 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--use-averaged-model True
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import Aishell4AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from local.text_normalize import text_normalize
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
Aishell4AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text)
return c
aishell4 = Aishell4AsrDataModule(args)
test_cuts = aishell4.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
test_dl = aishell4.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dl = [test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,275 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless5/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell4/ASR
./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,358 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When use-averaged-model=True, usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--decoding-method greedy_search \
--use-averaged-model True \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--use-averaged-model True \
--decoding-method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search (not suggest)
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--use-averaged-model True \
--decoding-method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--use-averaged-model True \
--decoding-method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
./pruned_transducer_stateless5/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Path to lang.
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--decoding-method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.decoding_method}"
if params.decoding_method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding-method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/aishell4/ASR
python ./pruned_transducer_stateless5/test_model.py
"""
from train import get_params, get_transducer_model
def test_model_1():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 24
params.dim_feedforward = 1536 # 384 * 4
params.encoder_dim = 384
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
def test_model_M():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 18
params.dim_feedforward = 1024
params.encoder_dim = 256
params.nhead = 4
params.decoder_dim = 512
params.joiner_dim = 512
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
# test_model_1()
test_model_M()
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

1
egs/aishell4/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../egs/aishell/ASR/shared