Support Transformer LM (#750)

* support transformer LM

* show number of parameters during training

* update docstring

* testing files for ppl calculation

* add lm wrapper for rnn and transformer LM

* apply lm wrapper in lm shallow fusion

* small updates

* update decode.py to support LM fusion and LODR

* add export.py

* update CI and workflow

* update decoding results

* fix CI

* remove transformer LM from CI test
marcoyang1998 2022-12-29 10:53:36 +08:00 committed by GitHub
parent 3c54333b06
commit 1f0408b103
20 changed files with 3086 additions and 638 deletions

View File

@ -193,7 +193,7 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
ls -lh data
ls -lh lstm_transducer_stateless2/exp
log "Decoding test-clean and test-other"
log "Decoding test-clean and test-other with RNN LM"
./lstm_transducer_stateless2/decode.py \
--use-averaged-model 0 \
@ -201,12 +201,14 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
--avg 1 \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir $lm_repo/exp \
--rnn-lm-epoch 88 \
--rnn-lm-avg 1 \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir $lm_repo/exp \
--lm-epoch 88 \
--lm-avg 1 \
--lm-scale 0.3 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
fi
@ -245,11 +247,13 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then
--avg 1 \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_LODR \
--decoding-method modified_beam_search_LODR \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir $lm_repo/exp \
--rnn-lm-epoch 88 \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir $lm_repo/exp \
--lm-scale 0.4 \
--lm-epoch 88 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \

View File

@ -139,9 +139,10 @@ jobs:
cd egs/librispeech/ASR
tree lstm_transducer_stateless2/exp
cd lstm_transducer_stateless2/exp
echo "===modified_beam_search_rnnlm_shallow_fusion==="
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified_beam_search_lm_shallow_fusion==="
echo "===Using RNNLM==="
find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Display decoding results for lstm_transducer_stateless2
if: github.event.label.name == 'LODR'
@ -151,8 +152,8 @@ jobs:
tree lstm_transducer_stateless2/exp
cd lstm_transducer_stateless2/exp
echo "===modified_beam_search_rnnlm_LODR==="
find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for lstm_transducer_stateless2
uses: actions/upload-artifact@v2

View File

@ -320,6 +320,10 @@ Number of model parameters: 70369391, i.e., 70.37 M
|----------------------|------------|-------------|----------------------------------------|
| greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 |
| modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 |
| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 39 --avg 6 --max-duration 600 |
| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 39 --avg 6 --max-duration 600 |
| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 39 --avg 6 --max-duration 600 |
| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 39 --avg 6 --max-duration 600 |
| fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 |
The training commands are:
@ -458,7 +462,9 @@ The WERs are:
| greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
| modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
| modified_beam_search + RNNLM shallow fusion | 2.28 | 5.94 | --iter 468000 --avg 16 |
| modified_beam_search + TransformerLM shallow fusion | 2.37 | 6.48 | --iter 468000 --avg 16 |
| modified_beam_search + RNNLM + LODR | 2.24 | 5.89 | --iter 468000 --avg 16 |
| modified_beam_search + TransformerLM + LODR | 2.19 | 5.90 | --iter 468000 --avg 16 |
| fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
| greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
@ -513,9 +519,12 @@ for m in greedy_search fast_beam_search modified_beam_search; do
done
```
To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
You may also decode using shallow fusion with an external neural network LM. To do so, you need to
download a well-trained NN LM:
RNN LM: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
Transformer LM: <https://huggingface.co/marcoyang/icefall-librispeech-transformer-lm/tree/main>
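During decoding, shallow fusion simply adds the scaled log-probability of each hypothesized token under the NN LM to the transducer score. In the notation used here (with the scale $\lambda$ set by `--lm-scale`), the per-token score is roughly
$$\log p(y_u) = \log p_{\text{transducer}}(y_u \mid \mathbf{x}) + \lambda \, \log p_{\text{NN}}(y_u \mid y_{<u})$$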
```bash
for iter in 472000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
@ -523,23 +532,24 @@ for iter in 472000; do
--avg $avg \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
--decoding-method modified_beam_search_lm_shallow_fusion \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
done
done
```
You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in <https://arxiv.org/pdf/2203.16776.pdf>.
You may also decode using LODR + LM shallow fusion. This decoding method is proposed in <https://arxiv.org/pdf/2203.16776.pdf>.
It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
generated by `generate-lm.sh`, or you may download it from <https://huggingface.co/marcoyang/librispeech_bigram>.
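LODR keeps the shallow-fusion term above and additionally adds a scaled bi-gram score that serves as a cheap estimate of the transducer's internal LM. With `--lm-scale` $= \lambda_1$ and `--ngram-lm-scale` $= \lambda_2$ (a negative value such as $-0.16$), the per-token score becomes roughly
$$\log p(y_u) = \log p_{\text{transducer}}(y_u \mid \mathbf{x}) + \lambda_1 \, \log p_{\text{NN}}(y_u \mid y_{<u}) + \lambda_2 \, \log p_{\text{bi-gram}}(y_u \mid y_{u-1})$$
Since $\lambda_2 < 0$, the bi-gram contribution is effectively subtracted, as described above.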
The decoding command is as follows:
```bash
for iter in 472000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
@ -547,18 +557,22 @@ for iter in 472000; do
--avg $avg \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_LODR \
--decoding-method modified_beam_search_LODR \
--beam 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \
--token-ngram 2 \
--max-contexts 4 \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \
--lm-epoch 99 \
--lm-scale 0.4 \
--lm-avg 1 \
--tokens-ngram 2 \
--ngram-lm-scale -0.16
done
done
```
Note that you can also set `--lm-type transformer` to use a transformer LM during LODR, but it will be slower
because it has not been optimized. The pre-trained transformer LM is available at <https://huggingface.co/marcoyang/icefall-librispeech-transformer-lm/tree/main>.
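Under the hood, `--lm-type` selects which model the new `LmScorer` wrapper (imported from `icefall.lm_wrapper` in this commit) builds. Below is a minimal sketch of how `decode.py` constructs it for the transformer case; the exact fields read from `params` are an assumption here and would normally be filled in by `LmScorer.add_arguments(parser)`:
```python
import torch
from icefall import LmScorer
from icefall.utils import AttributeDict

# Hypothetical stand-in for the values produced by get_parser()/get_params();
# the real fields come from the --lm-* options registered by LmScorer.add_arguments.
params = AttributeDict(
    {
        "vocab_size": 500,
        "lm_exp_dir": "/path/to/LM",  # directory holding epoch-99.pt, etc.
        "lm_epoch": 99,
        "lm_avg": 1,
    }
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LM = LmScorer(
    lm_type="transformer",  # "rnn" selects the RNN LM instead
    params=params,
    device=device,
    lm_scale=0.4,           # plays the same role as --lm-scale
)
LM.to(device)
LM.eval()
```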
Pretrained models, training logs, decoding logs, and decoding results
are available at
@ -1717,6 +1731,9 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di
| greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search + RNNLM + LODR | 2.23 | 5.17 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search + TransformerLM shallow fusion | 2.27 | 5.26 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search + TransformerLM + LODR | 2.22 | 5.11 | --epoch 30 --avg 10 --max-duration 600 |
| fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
```bash
@ -2080,7 +2097,8 @@ subset so that the gigaspeech dataloader never exhausts.
| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
| modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
| modified beam search + rnnlm shallow fusion | 1.94 | 4.2 | --iter 1224000 --avg 14 --max-duration 600 |
| modified beam search + LODR | 1.83 | 4.03 | --iter 1224000 --avg 14 --max-duration 600 |
| modified beam search + rnnlm + LODR | 1.77 | 3.99 | --iter 1224000 --avg 14 --max-duration 600 |
| modified beam search + TransformerLM + LODR | 1.75 | 3.94 | --iter 1224000 --avg 14 --max-duration 600 |
| fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
The training commands are:
@ -2126,8 +2144,10 @@ for iter in 1224000; do
done
done
```
You may also decode using shallow fusion with external RNNLM. To do so you need to
download a well-trained RNNLM from this link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
You may also decode using shallow fusion with an external neural network LM. To do so, you need to
download a well-trained NN LM:
RNN LM: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
Transformer LM: <https://huggingface.co/marcoyang/icefall-librispeech-transformer-lm/tree/main>
```bash
rnn_lm_scale=0.3

View File

@ -93,36 +93,37 @@ Usage:
--max-contexts 8 \
--max-states 64
(8) modified beam search (with RNNLM shallow fusion)
(8) modified beam search (with LM shallow fusion)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
(9) modified beam search with RNNLM shallow fusion + LODR
(9) modified beam search with LM shallow fusion + LODR
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--max-duration 600 \
--exp-dir ./lstm_transducer_stateless2/exp \
--decoding-method modified_beam_search_rnnlm_LODR \
--decoding-method modified_beam_search_LODR \
--beam 4 \
--max-contexts 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--lm-type rnn \
--lm-scale 0.4 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \
--rnn-lm-tie-weights 1
--tokens-ngram 2 \
--ngram-lm-scale -0.16 \
"""
@ -148,14 +149,14 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
modified_beam_search_ngram_rescoring,
modified_beam_search_rnnlm_LODR,
modified_beam_search_rnnlm_shallow_fusion,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
from icefall import NgramLm
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -163,7 +164,6 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
@ -253,8 +253,8 @@ def get_parser():
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_ngram_rescoring
- modified_beam_search_rnnlm_shallow_fusion
- modified_beam_search_rnnlm_LODR
- modified_beam_search_lm_shallow_fusion
- modified_beam_search_LODR
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -344,67 +344,28 @@ def get_parser():
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
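Note that the per-architecture options removed above (`--rnn-lm-embedding-dim`, `--rnn-lm-num-layers`, `--rnn-lm-tie-weights`, and so on) have not disappeared: judging from the updated usage strings, they are now registered by the LM wrapper itself via `LmScorer.add_arguments(parser)`, so `main()` only assembles the parser along these lines (a sketch of the pattern adopted in this diff):
```python
# Sketch: get_parser() and AsrDataModule live in this recipe; LmScorer
# contributes the LM-specific flags (--lm-exp-dir, --lm-epoch, --lm-avg, ...).
from icefall import LmScorer

parser = get_parser()
AsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
```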
@ -440,8 +401,7 @@ def decode_one_batch(
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -470,6 +430,9 @@ def decode_one_batch(
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
is set to true.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -581,20 +544,19 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
hyp_tokens = modified_beam_search_rnnlm_LODR(
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
@ -602,8 +564,7 @@ def decode_one_batch(
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -658,8 +619,7 @@ def decode_dataset(
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -678,6 +638,8 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -711,8 +673,7 @@ def decode_dataset(
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -730,6 +691,7 @@ def decode_dataset(
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -781,6 +743,7 @@ def save_results(
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -795,9 +758,9 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_rnnlm_LODR",
"modified_beam_search_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_ngram_rescoring",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -820,12 +783,18 @@ def main():
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if "rnnlm" in params.decoding_method:
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += "-LODR"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -954,28 +923,19 @@ def main():
ngram_lm = None
ngram_lm_scale = None
# only load rnnlm if used
if "rnnlm" in params.decoding_method:
rnn_lm_scale = params.rnn_lm_scale
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
assert params.rnn_lm_avg == 1
LM.to(device)
LM.eval()
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
else:
rnn_lm_model = None
rnn_lm_scale = 0.0
LM = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
@ -1003,7 +963,9 @@ def main():
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts()
# test_clean_cuts = test_clean_cuts.subset(first=500)
test_other_cuts = librispeech.test_other_cuts()
# test_other_cuts = test_other_cuts.subset(first=500)
test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
@ -1021,8 +983,7 @@ def main():
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
LM=LM,
)
save_results(

View File

@ -26,7 +26,9 @@ from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
DecodingResults,
add_eos,
@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring(
return ans
def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + RNNLM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
rnnlm (RnnLmModel):
RNNLM
rnnlm_scale (float):
scale of RNNLM in shallow fusion
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = []
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
def modified_beam_search_rnnlm_LODR(
def modified_beam_search_LODR(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LODR_lm: NgramLm,
LODR_lm_scale: float,
rnnlm: RnnLmModel,
rnnlm_scale: float,
LM: LmScorer,
beam: int = 4,
) -> List[List[int]]:
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR(
sp:
Sentence piece generator.
LODR_lm:
A low order n-gram LM
A low order n-gram LM, whose score will be subtracted during shallow fusion
LODR_lm_scale:
The scale of the LODR_lm
rnnlm (RnnLmModel):
RNNLM, the external language model
rnnlm_scale (float):
scale of RNNLM in shallow fusion
LM:
A neural net LM, e.g. an RNNLM or transformer LM
beam (int, optional):
Beam size. Defaults to 4.
@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR(
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
assert LM is not None
lm_scale = LM.lm_scale
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR(
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
lens = torch.tensor([1]).to(device)
init_score, init_states = LM.score_token(sos_token, lens)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states, # state of the RNNLM
state=init_states, # state of the NN LM
lm_score=init_score.reshape(-1),
state_cost=NgramLmStateCost(
LODR_lm
@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR(
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR(
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
LM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR(
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
if LM.lm_type == "rnn":
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append(
[sos_id] + hyp.ys[context_size:] + [new_token]
)
# forward RNNLM to get new states and scores
# forward NN LM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
state = None
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR(
state_cost.lm_score,
hyp.state_cost.lm_score,
)
# score = score + RNNLM_score - LODR_score
# LODR_LM_scale is a negative number here
# score = score + TDLM_score - LODR_score
# LODR_LM_scale should be a negative number here
hyp_log_prob += (
lm_score[new_token] * lm_scale
+ LODR_lm_scale * current_ngram_score
) # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
if LM.lm_type == "rnn":
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
else:
state_cost = hyp.state_cost
@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR(
ans.append(sorted_ans[unsorted_indices[i]])
return ans
def modified_beam_search_lm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LM: LmScorer,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + NN LM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
LM (LmScorer):
A neural net LM, e.g. RNN or Transformer
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert LM is not None
lm_scale = LM.lm_scale
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
lens = torch.tensor([1]).to(device)
init_score, init_states = LM.score_token(sos_token, lens)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
lm_scores = torch.cat(
[hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
`LM` will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = [] # a list of list
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
if LM.lm_type == "rnn":
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append(
[sos_id] + hyp.ys[context_size:] + [new_token]
)
if len(token_list) != 0:
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
state = None
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
lm_score = scores[count]
if LM.lm_type == "rnn":
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
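For reference, the call site in `decode.py` (per the `decode_one_batch` changes elsewhere in this commit) is expected to look roughly like the sketch below; `model`, `encoder_out`, `encoder_out_lens`, `sp`, `params`, and `LM` are assumed to be set up as in that file:
```python
# Sketch of the call site, mirroring the decode.py changes in this commit.
hyp_tokens = modified_beam_search_lm_shallow_fusion(
    model=model,                        # the transducer model
    encoder_out=encoder_out,            # (N, T, C) encoder output
    encoder_out_lens=encoder_out_lens,  # (N,) number of valid frames
    beam=params.beam_size,
    sp=sp,                              # SentencePiece processor
    LM=LM,                              # LmScorer wrapping the RNN or transformer LM
)
hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]
```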

View File

@ -92,36 +92,37 @@ Usage:
--max-contexts 8 \
--max-states 64
(8) modified beam search (with RNNLM shallow fusion)
(8) modified beam search (with LM shallow fusion)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
(9) modified beam search with RNNLM shallow fusion + LODR
(9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--max-duration 600 \
--exp-dir ./pruned_transducer_stateless3/exp \
--decoding-method modified_beam_search_rnnlm_LODR \
--beam 4 \
--max-contexts 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--decoding-method modified_beam_search_LODR \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.4 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \
--rnn-lm-tie-weights 1
--tokens-ngram 2 \
--ngram-lm-scale -0.16 \
"""
@ -149,14 +150,14 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
modified_beam_search_ngram_rescoring,
modified_beam_search_rnnlm_LODR,
modified_beam_search_rnnlm_shallow_fusion,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
from icefall import NgramLm
from icefall import LmScorer, NgramLm
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
@ -240,8 +241,8 @@ def get_parser():
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_ngram_rescoring
- modified_beam_search_rnnlm_shallow_fusion
- modified_beam_search_rnnlm_LODR
- modified_beam_search_lm_shallow_fusion
- modified_beam_search_LODR
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -392,58 +393,28 @@ def get_parser():
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
"--use-shallow-fusion",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
@ -481,7 +452,7 @@ def decode_one_batch(
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnn_lm_model: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -515,10 +486,9 @@ def decode_one_batch(
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It is an FsaVec containing an acceptor.
rnn_lm_model:
A rnnlm which can be used for rescoring or shallow fusion
rnnlm_scale:
The scale of the rnnlm.
LM:
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
is set to true.
ngram_lm:
A ngram lm. Used in LODR decoding.
ngram_lm_scale:
@ -697,20 +667,19 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnn_lm_model,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
hyp_tokens = modified_beam_search_rnnlm_LODR(
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
@ -718,8 +687,7 @@ def decode_one_batch(
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
rnnlm=rnn_lm_model,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -812,7 +780,7 @@ def decode_dataset(
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnn_lm_model: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -836,6 +804,8 @@ def decode_dataset(
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It's an FsaVec containing an acceptor.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -871,7 +841,7 @@ def decode_dataset(
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnn_lm_model=rnn_lm_model,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -1005,6 +975,7 @@ def load_ngram_LM(
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -1022,9 +993,9 @@ def main():
"modified_beam_search",
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
"modified_beam_search_rnnlm_LODR",
"modified_beam_search_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_ngram_rescoring",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -1055,12 +1026,18 @@ def main():
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-temperature-{params.temperature}"
if "rnnlm" in params.decoding_method:
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += "-LODR"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -1195,28 +1172,19 @@ def main():
ngram_lm = None
ngram_lm_scale = None
# only load rnnlm if used
if "rnnlm" in params.decoding_method:
rnn_lm_scale = params.rnn_lm_scale
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
assert params.rnn_lm_avg == 1
LM.to(device)
LM.eval()
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
else:
rnn_lm_model = None
rnn_lm_scale = 0.0
LM = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -1247,7 +1215,7 @@ def main():
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnn_lm_model=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
LM=LM,
)
save_results(

View File

@ -87,22 +87,39 @@ Usage:
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion (with LG)
(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 4 \
--max-contexts 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--decoding-method modified_beam_search_lm_shallow_fusion \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
(9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--max-duration 600 \
--exp-dir ./pruned_transducer_stateless5/exp \
--decoding-method modified_beam_search_LODR \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.4 \
--lm-exp-dir /path/to/LM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
--tokens-ngram 2 \
--ngram-lm-scale -0.16 \
"""
@ -128,10 +145,13 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
modified_beam_search_ngram_rescoring,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -139,7 +159,6 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
@ -229,7 +248,8 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
- modified_beam_search_LODR
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -342,69 +362,49 @@ def get_parser():
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=3,
help="""Token Ngram used for rescoring.
Used only when the decoding method is
modified_beam_search_ngram_rescoring, or LODR
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="""ID of the backoff symbol.
Used only when the decoding method is
modified_beam_search_ngram_rescoring""",
)
add_model_arguments(parser)
return parser
@ -417,8 +417,9 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -447,6 +448,13 @@ def decode_one_batch(
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
is set to true.
ngram_lm:
A ngram lm. Used in LODR decoding.
ngram_lm_scale:
The scale of the ngram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -559,15 +567,38 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
elif params.decoding_method == "modified_beam_search_ngram_rescoring":
hyp_tokens = modified_beam_search_ngram_rescoring(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -620,8 +651,9 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -640,6 +672,8 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -663,7 +697,6 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
logging.info(f"Decoding {batch_idx}-th batch")
hyps_dict = decode_one_batch(
params=params,
@ -672,8 +705,9 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -742,6 +776,7 @@ def save_results(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -757,7 +792,8 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_rnnlm_shallow_fusion",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -783,7 +819,18 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -895,24 +942,34 @@ def main():
model.to(device)
model.eval()
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
# only load N-gram LM when needed
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
assert params.rnn_lm_avg == 1
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
LM.to(device)
LM.eval()
else:
LM = None
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
@ -955,8 +1012,9 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
save_results(

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -91,6 +92,41 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--use-shallow-fusion 1 \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.3 \
--lm-exp-dir /path/to/LM \
--lm-epoch 99 \
--lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
(9) modified beam search with LM shallow fusion + LODR
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--max-duration 600 \
--exp-dir ./pruned_transducer_stateless5/exp \
--decoding-method modified_beam_search_LODR \
--use-shallow-fusion 1 \
--beam-size 4 \
--lm-type rnn \
--lm-scale 0.4 \
--lm-exp-dir /path/to/LM \
--lm-epoch 99 \
--lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1 \
--tokens-ngram 2 \
--ngram-lm-scale -0.16 \
"""
@ -115,9 +151,13 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
modified_beam_search_ngram_rescoring,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -213,6 +253,8 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_lm_shallow_fusion # for rnn/transformer lm shallow fusion
- modified_beam_search_LODR
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -274,6 +316,7 @@ def get_parser():
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
@ -323,6 +366,50 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=3,
help="""Token Ngram used for rescoring.
Used only when the decoding method is
modified_beam_search_ngram_rescoring, or LODR
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="""ID of the backoff symbol.
Used only when the decoding method is
modified_beam_search_ngram_rescoring or modified_beam_search_LODR""",
)
add_model_arguments(parser)
return parser
@ -335,6 +422,9 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -363,6 +453,13 @@ def decode_one_batch(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
is set to true.
ngram_lm:
An n-gram LM. Used in LODR decoding.
ngram_lm_scale:
The scale of the n-gram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -468,6 +565,30 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
@ -517,6 +638,9 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -535,6 +659,8 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -566,6 +692,9 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
for name, hyps in hyps_dict.items():
@ -634,6 +763,7 @@ def save_results(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -648,6 +778,8 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -675,6 +807,19 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if "ngram" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if params.use_shallow_fusion:
if params.lm_type == "rnn":
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
elif params.lm_type == "transformer":
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -785,6 +930,34 @@ def main():
model.to(device)
model.eval()
# only load N-gram LM when needed
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
# only load the neural network LM if doing shallow fusion
if params.use_shallow_fusion:
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
@ -826,6 +999,9 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
LM=LM,
)
save_results(

View File

@ -68,3 +68,5 @@ from .utils import (
)
from .ngram_lm import NgramLm, NgramLmStateCost
from .lm_wrapper import LmScorer

254
icefall/lm_wrapper.py Normal file
View File

@ -0,0 +1,254 @@
# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import AttributeDict, str2bool
class LmScorer(torch.nn.Module):
"""This is a wrapper for NN LMs
The language models supported include:
RNN,
Transformer
"""
def __init__(
self,
lm_type: str,
params: AttributeDict,
device,
lm_scale: float = 0.3,
):
super(LmScorer, self).__init__()
assert lm_type in ["rnn", "transformer"], f"{lm_type} is not supported"
self.lm_type = lm_type
self.lm = self.get_lm(lm_type, device, params)
self.lm_scale = lm_scale
self.params = params
@classmethod
def add_arguments(cls, parser):
# LM general arguments
parser.add_argument(
"--vocab-size",
type=int,
default=500,
)
parser.add_argument(
"--lm-epoch",
type=int,
default=7,
help="""Which epoch to be used
""",
)
parser.add_argument(
"--lm-avg",
type=int,
default=1,
help="""Number of checkpoints to be averaged
""",
)
parser.add_argument("--lm-exp-dir", type=str, help="Path to LM experiments")
# Now RNNLM related arguments
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
# Now transformers
parser.add_argument(
"--transformer-lm-exp-dir", type=str, help="Directory of transformer LM exp"
)
parser.add_argument(
"--transformer-lm-dim-feedforward",
type=int,
default=2048,
help="Dimension of FFW module in transformer",
)
parser.add_argument(
"--transformer-lm-encoder-dim",
type=int,
default=768,
help="Encoder dimension of transformer",
)
parser.add_argument(
"--transformer-lm-embedding-dim",
type=int,
default=768,
help="Input embedding dimension of transformer",
)
parser.add_argument(
"--transformer-lm-nhead",
type=int,
default=8,
help="Number of attention heads in transformer",
)
parser.add_argument(
"--transformer-lm-num-layers",
type=int,
default=16,
help="Number of encoder layers in transformer",
)
parser.add_argument(
"--transformer-lm-tie-weights",
type=str2bool,
default=True,
help="If tie weights in transformer LM",
)
def get_lm(self, lm_type: str, device, params: AttributeDict) -> torch.nn.Module:
"""Return the neural network LM
Args:
lm_type (str): Type name of NN LM
"""
if lm_type == "rnn":
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.lm_avg == 1:
load_checkpoint(
f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
)
model.to(device)
else:
start = params.lm_epoch - params.lm_avg + 1
filenames = []
for i in range(start, params.lm_epoch + 1):
if start >= 0:
filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif lm_type == "transformer":
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward,
nhead=params.transformer_lm_nhead,
num_layers=params.transformer_lm_num_layers,
tie_weights=params.transformer_lm_tie_weights,
params=params,
)
if params.lm_avg == 1:
load_checkpoint(
f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
)
model.to(device)
else:
start = params.lm_epoch - params.lm_avg + 1
filenames = []
for i in range(start, params.lm_epoch + 1):
if start >= 0:
filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
raise NotImplementedError()
return model
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
"""Score the input and return the prediction
This requires the lm to have the method `score_token`
Args:
x (torch.Tensor): Input tokens
x_lens (torch.Tensor): Length of the input tokens
state (optional): LM states
"""
return self.lm.score_token(x, x_lens, state)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
LmScorer.add_arguments(parser)
args = parser.parse_args()
params = AttributeDict()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
Scorer = LmScorer(lm_type="rnn", params=params, device=device)  # lm_type is required; "rnn" assumed here
Scorer.eval()
x = (
torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 256, 77]])
.to(device)
.to(torch.int64)
)
x_lens = torch.tensor([5, 5]).to(device)
state = None
score, state = Scorer.score_token(x, x_lens, state)
print(score.shape)
print(score[0])
print(score[1])
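For reference, a minimal sketch (not part of this diff; the exp dir and epoch are placeholders and a trained RNN LM checkpoint is assumed to exist there) of how the decoding scripts construct and query the wrapper:

import torch
from icefall import LmScorer
from icefall.utils import AttributeDict

params = AttributeDict({
    "vocab_size": 500,
    "lm_epoch": 99,
    "lm_avg": 1,
    "lm_exp_dir": "/path/to/LM/exp",
    "rnn_lm_embedding_dim": 2048,
    "rnn_lm_hidden_dim": 2048,
    "rnn_lm_num_layers": 3,
    "rnn_lm_tie_weights": True,
})
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
LM = LmScorer(lm_type="rnn", params=params, device=device, lm_scale=0.3)
LM.eval()
tokens = torch.tensor([[1, 4, 19]], dtype=torch.int64, device=device)
token_lens = torch.tensor([3], device=device)
# log-probs over the vocabulary for the next token of each hypothesis
log_probs, state = LM.score_token(tokens, token_lens)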

View File

@ -153,9 +153,24 @@ class RnnLmModel(torch.nn.Module):
def clean_cache(self):
self.cache = {}
def score_token(self, tokens: torch.Tensor, state=None):
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
"""Score a batch of tokens
Args:
x (torch.Tensor):
A batch of tokens
x_lens (torch.Tensor):
The length of tokens in the batch before padding
state (optional):
Either None or a tuple of two torch.Tensor. Each tensor has
the shape of (hidden_dim)
Returns:
The log-probs of the next token for each sequence and the updated LM states (h, c).
"""
device = next(self.parameters()).device
batch_size = tokens.size(0)
batch_size = x.size(0)
if state:
h, c = state
else:
@ -166,7 +181,7 @@ class RnnLmModel(torch.nn.Module):
device
)
embedding = self.input_embedding(tokens)
embedding = self.input_embedding(x)
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)

View File

@ -531,6 +531,9 @@ def run(rank, world_size, args):
tie_weights=params.tie_weights,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)

View File

@ -0,0 +1,510 @@
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.transformer_lm.scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
)
from icefall.utils import is_jit_tracing
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = ScaledLinear(
embed_dim, embed_dim, bias=True, initial_scale=0.25
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.01)
nn.init.normal_(self.pos_bias_v, std=0.01)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.get_weight(),
self.in_proj.get_bias(),
self.dropout,
self.out_proj.get_weight(),
self.out_proj.get_bias(),
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
left_context=left_context,
)
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
time2 = time1 + left_context
if not is_jit_tracing():
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
if is_jit_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
if not is_jit_tracing():
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
if not is_jit_tracing():
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None and not is_jit_tracing():
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
if not is_jit_tracing():
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self._pos_bias_v()).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd, left_context)
attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
if not is_jit_tracing():
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`, at this circumstances,
# the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
# positions to avoid invalid loss value below.
if (
attn_mask is not None
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
if not is_jit_tracing():
assert list(attn_output.size()) == [
bsz * num_heads,
tgt_len,
head_dim,
]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
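A small sketch (random values; assumes icefall is importable) of the rel_shift trick above, which re-indexes scores from (query position, relative offset) to (query position, key position):

import torch
from icefall.transformer_lm.attention import RelPositionMultiheadAttention

attn = RelPositionMultiheadAttention(embed_dim=8, num_heads=2)
# (batch, head, time1, 2*time1-1): one score per query position and relative offset
scores = torch.randn(1, 2, 3, 5)
shifted = attn.rel_shift(scores)
print(shifted.shape)  # torch.Size([1, 2, 3, 3]) -> (batch, head, time1, time2)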

View File

@ -0,0 +1,195 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from pathlib import Path
import torch
from dataset import get_dataloader
from train import get_params
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=7,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transformer_lm/exp_full_libri_16layer_maxlen200_8gpu",
)
parser.add_argument(
"--lm-data",
type=str,
help="Path to the LM test data for computing perplexity",
default="transformer_lm/libri_lm_training_bpe500/sorted_lm_data-test.pt",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=16,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Number of RNN layers the model",
)
parser.add_argument(
"--max-sent-len",
type=int,
default=100,
help="Number of RNN layers the model",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_data = Path(args.lm_data)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-ppl/")
logging.info("Computing perplexity started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.encoder_dim,
embedding_dim=params.embedding_dim,
dim_feedforward=params.dim_feedforward,
nhead=params.nhead,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
params=params,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(
[p.numel() for p in model.parameters() if p.requires_grad]
)
logging.info(f"Number of model parameters: {num_param}")
logging.info(
f"Number of model parameters (requires_grad): "
f"{num_param_requires_grad} "
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
)
logging.info(f"Loading LM test data from {params.lm_data}")
test_dl = get_dataloader(
filename=params.lm_data,
is_distributed=False,
params=params,
)
tot_loss = 0.0
num_tokens = 0
num_sentences = 0
for batch_idx, batch in enumerate(test_dl):
x, y, sentence_lengths = batch
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum().cpu().item()
tot_loss += loss
num_tokens += sentence_lengths.sum().cpu().item()
num_sentences += x.size(0)
ppl = math.exp(tot_loss / num_tokens)
logging.info(
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
)
if __name__ == "__main__":
main()
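As a sanity check of the perplexity formula above (numbers are illustrative, not results from this PR): with tot_loss = 2,000,000 nats over num_tokens = 1,000,000 tokens, ppl = exp(2.0) ≈ 7.39.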

View File

@ -0,0 +1 @@
../rnn_lm/dataset.py

View File

@ -0,0 +1,329 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from icefall.transformer_lm.attention import RelPositionMultiheadAttention
from icefall.transformer_lm.scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
)
from icefall.utils import is_jit_tracing, make_pad_mask
class Transformer(torch.nn.Module):
"""_summary_
Args:
input_dim (int): Input feature dimension
d_mode (int): The dimension of the transformer
dim_feedforward (int ): The dimension of the ffw module
nhead (int): The number of attention heads
dropout_rate (float): dropout rate
att_dropout (float): dropout rate in attention module
"""
def __init__(
self,
input_dim: int,
d_model: int,
dim_feedforward: int,
nhead: int = 4,
num_layers: int = 6,
dropout_rate: float = 0.1,
att_dropout: float = 0.0,
):
super().__init__()
self.encoder_layers = num_layers
self.d_model = d_model
self.embed = ScaledLinear(input_dim, d_model)
self.norm_before = BasicNorm(d_model, learn_eps=False)
self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
dim_feedforward=dim_feedforward,
nhead=nhead,
dropout_rate=dropout_rate,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
def _create_attention_mask(self, x_lens: torch.Tensor):
# create a 2D attention mask to mask out
# the upper right half of the attention matrix
max_len = max(x_lens)
ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool)
return torch.triu(ones, diagonal=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Transformer forward
Args:
x (torch.Tensor): Input tensor (B,T,input_dim)
x_lens (torch.Tensor): The length of input tensors before padding (B,)
Returns:
Return a tuple of 2 tensors:
- x: output feature of the transformer (B,T,d_model)
- x_lens: output feature lens of the transformer
"""
attention_mask = self._create_attention_mask(x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = self.norm_before(self.embed(x))
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2)
x = self.encoder(
x,
pos_emb,
mask=attention_mask, # pass the causal attention mask
src_key_padding_mask=src_key_padding_mask,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, x_lens
class TransformerEncoder(torch.nn.Module):
def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None:
"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer (torch.nn.Module): an instance of the TransformerEncoderLayer()
num_layers (int): Number of layers to be stacked
"""
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
def forward(
self,
src: torch.Tensor,
pos_emb: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""_summary_
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
output: transformer encoded features
"""
output = src
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
src_key_padding_mask=src_key_padding_mask,
src_mask=mask,
)
return output
class TransformerEncoderLayer(torch.nn.Module):
def __init__(
self,
d_model: int,
dim_feedforward: int,
nhead: int,
dropout_rate: float,
):
"""TransformerEncoderLayer is made up of self-attn and feedforward module
Args:
d_model (int): The model size
dim_feedforward (int): Dimension of ffw module
nhead (int): Number of heads
dropout_rate (float): Dropout rate
"""
super().__init__()
self.d_model = d_model
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout_rate),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.norm_final = BasicNorm(d_model)
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
src: torch.Tensor,
pos_emb: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
src_mask: Optional[torch.Tensor] = None,
cache=None,
):
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_key_padding_mask: the mask for the src keys per batch (optional).
src_mask: the mask for the src sequence (optional).
"""
src_orig = src
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
return src
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
if is_jit_tracing():
# 10k frames correspond to ~100k ms, i.e., 100 seconds.
# It assumes that the maximum input won't have more than
# 10k frames.
#
# TODO(fangjun): Use torch.jit.script() for this module
max_len = 10000
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = x.size(1) + left_context
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means the position of the query vector and `j` means the
# position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(
self,
x: torch.Tensor,
left_context: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x, left_context)
x_size_1 = x.size(1) + left_context
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
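A quick shape check (random weights, no checkpoint; assumes icefall is importable) for the causal encoder defined above:

import torch
from icefall.transformer_lm.encoder import Transformer

enc = Transformer(input_dim=768, d_model=768, dim_feedforward=2048, nhead=8, num_layers=2)
enc.eval()
x = torch.randn(2, 10, 768)     # (batch, time, input_dim)
x_lens = torch.tensor([10, 7])  # valid lengths before padding
with torch.no_grad():
    y, y_lens = enc(x, x_lens)
print(y.shape)  # torch.Size([2, 10, 768]); position t only attends to positions <= t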

View File

@ -0,0 +1,186 @@
#!/usr/bin/env python3
# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from model import TransformerLM
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, load_averaged_model, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=11,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=768,
help="Embedding dim of the model",
)
parser.add_argument(
"--encoder-dim",
type=int,
default=768,
help="Encoder dim of the model",
)
parser.add_argument(
"--dim_feedforward",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--nhead",
type=int,
default=8,
help="Number of attention heads",
)
parser.add_argument(
"--num-layers",
type=int,
default=16,
help="Number of Transformer layers",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = AttributeDict({})
params.update(vars(args))
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("About to create model")
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.encoder_dim,
embedding_dim=params.embedding_dim,
dim_feedforward=params.dim_feedforward,
nhead=params.nhead,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
params=params,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
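A hypothetical invocation of this export script (paths and epoch are placeholders; epoch checkpoints are assumed to exist under --exp-dir):

./transformer_lm/export.py \
--exp-dir transformer_lm/exp \
--epoch 11 \
--avg 5 \
--num-layers 16 \
--tie-weights 1 \
--jit 1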

View File

@ -0,0 +1,115 @@
# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from icefall.transformer_lm.encoder import Transformer
from icefall.utils import AttributeDict, add_eos, add_sos, make_pad_mask
class TransformerLM(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
d_model: int,
dim_feedforward: int,
nhead: int = 8,
num_layers: int = 16,
tie_weights: bool = True,
dropout: float = 0.1,
emb_dropout_rate: float = 0.0,
params: AttributeDict = None,
):
super().__init__()
self.vocab_size = vocab_size
self.params = params
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
)
self.encoder = Transformer(
input_dim=embedding_dim,
d_model=d_model,
dim_feedforward=dim_feedforward,
nhead=nhead,
num_layers=num_layers,
dropout_rate=dropout,
)
self.output_linear = torch.nn.Linear(
in_features=d_model, out_features=vocab_size
)
if tie_weights:
logging.info("Tying weights")
assert d_model == embedding_dim, (d_model, embedding_dim)
self.output_linear.weight = self.input_embedding.weight
else:
logging.info("Not tying weights")
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
x_lens: torch.Tensor,
return_logits: bool = False,
):
"""Forward transformer language model
Args:
x (torch.Tensor): Input tokens (B,L)
y (torch.Tensor): Output tokens (with EOS appended) (B,L)
x_lens (torch.Tensor): Length of input tokens before padding (B,)
return_logits (bool, optional): Return logits instead of NLL
"""
x = self.input_embedding(x)
x, x_lens = self.encoder(x, x_lens)
logits = self.output_linear(x)
if return_logits:
return logits
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
mask = make_pad_mask(x_lens).reshape(-1)
nll_loss.masked_fill_(mask, 0)
return nll_loss
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
bs = x.size(0)
state = None
logits = self.forward(x, x, x_lens, return_logits=True)
index = torch.arange(bs)
last_logits = logits[index, x_lens - 1, :]
return last_logits.log_softmax(-1), state
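A minimal sketch (random weights, no checkpoint; assumes icefall is importable) of what score_token returns, namely next-token log-probs for each padded sequence:

import torch
from icefall.transformer_lm.model import TransformerLM

lm = TransformerLM(vocab_size=500, embedding_dim=768, d_model=768,
                   dim_feedforward=2048, nhead=8, num_layers=2)
lm.eval()
x = torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 0, 0]])  # padded token ids
x_lens = torch.tensor([5, 3])                              # lengths before padding
with torch.no_grad():
    log_probs, state = lm.score_token(x, x_lens)
print(log_probs.shape)  # torch.Size([2, 500]): log-probs of the next token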

View File

@ -0,0 +1 @@
../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1,609 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./transformer_lm/train.py \
--start-epoch 0 \
--world-size 2 \
--num-epochs 1 \
--use-fp16 0 \
--num-layers 12 \
--batch-size 400
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from dataset import get_dataloader
from lhotse.utils import fix_random_seed
from model import TransformerLM
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transformer_lm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=True,
help="Whether to use half precision training.",
)
parser.add_argument(
"--batch-size",
type=int,
default=400,
help="Number of sentences in each training batch.",
)
parser.add_argument(
"--lm-data",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data.pt",
help="LM training data",
)
parser.add_argument(
"--lm-data-valid",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
help="LM validation data",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=12,
help="Number of Transformer layers in the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
"lr": 1e-3,
"weight_decay": 1e-6,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 200,
"reset_interval": 2000,
"valid_interval": 1000,
"nhead": 8,
"embedding_dim": 768,
"encoder_dim": 768,
"dim_feedforward": 2048,
"dropout": 0.1,
"env_info": get_env_info(),
}
)
return params
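# Note: these defaults are merged with the parsed command-line arguments in
# `run()` via `params.update(vars(args))`, so flags such as `--num-layers 16`
# add to or override the entries above.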
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
The saved training stats loaded from the checkpoint, or None if no
checkpoint was loaded.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model,
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
A 1-D tensor containing number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
return loss, loss_info
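# Illustration (hypothetical BPE ids) of the layout compute_loss expects for a
# single sentence, assuming sos_id == eos_id == 1 as in get_params():
#
#   x                = [[1, 25, 7, 93]]   # SOS followed by the sentence
#   y                = [[25, 7, 93, 1]]   # the sentence followed by EOS
#   sentence_lengths = [4]                # length of each row before padding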
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
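# The perplexities logged below are derived from the accumulated stats as
# exp(total NLL / total tokens), e.g.
#   valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
# so a per-token NLL of 3.0 corresponds to a perplexity of about 20.1.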
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The running average of the per-token training loss is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of processes (GPUs) used in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
is_distributed = world_size > 1
fix_random_seed(params.seed)
if is_distributed:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.encoder_dim,
embedding_dim=params.embedding_dim,
dim_feedforward=params.dim_feedforward,
nhead=params.nhead,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
params=params,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if is_distributed:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=is_distributed,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=is_distributed,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if is_distributed:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if is_distributed:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()