mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
[recipe] LibriSpeech zipformer_ctc (#941)
* merge upstream * initial commit for zipformer_ctc * remove unwanted changes * remove changes to other recipe * fix zipformer softlink * fix for JIT export * add missing file * fix symbolic links * update results * Update RESULTS.md Address comments from @csukuangfj --------- Co-authored-by: zr_jin <peter.jin.cn@gmail.com>
This commit is contained in:
parent
5cebecf2dc
commit
7d56685734
@ -47,6 +47,7 @@ We place an additional Conv1d layer right after the input embedding layer.
|
||||
| `conformer-ctc` | Conformer | Use auxiliary attention head |
|
||||
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |
|
||||
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |
|
||||
| `zipformer-ctc` | Zipformer | Use auxiliary attention head |
|
||||
| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head | The latest recipe |
|
||||
|
||||
# MMI
|
||||
|
@ -375,6 +375,55 @@ for m in greedy_search modified_beam_search fast_beam_search; do
|
||||
done
|
||||
```
|
||||
|
||||
### Zipformer CTC
|
||||
|
||||
#### [zipformer_ctc](./zipformer_ctc)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/941> for more details.
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding
|
||||
results at:
|
||||
<https://huggingface.co/desh2608/icefall-asr-librispeech-zipformer-ctc>
|
||||
|
||||
Number of model parameters: 86083707, i.e., 86.08 M
|
||||
|
||||
| decoding method | test-clean | test-other | comment |
|
||||
|-------------------------|------------|------------|---------------------|
|
||||
| ctc-decoding | 2.50 | 5.86 | --epoch 30 --avg 9 |
|
||||
| whole-lattice-rescoring | 2.44 | 5.38 | --epoch 30 --avg 9 |
|
||||
| attention-rescoring | 2.35 | 5.16 | --epoch 30 --avg 9 |
|
||||
| 1best | 2.01 | 4.61 | --epoch 30 --avg 9 |
|
||||
|
||||
The training commands are:
|
||||
```bash
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./zipformer_ctc/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer_ctc/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 1000 \
|
||||
--master-port 12345
|
||||
```
|
||||
|
||||
The tensorboard log can be found at:
|
||||
<https://tensorboard.dev/experiment/IjPSJjHOQFKPYA5Z0Vf8wg>
|
||||
|
||||
The decoding command is:
|
||||
|
||||
```bash
|
||||
./zipformer_ctc/decode.py \
|
||||
--epoch 30 --avg 9 --use-averaged-model True \
|
||||
--exp-dir zipformer_ctc/exp \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--lm-dir data/lm \
|
||||
--method ctc-decoding
|
||||
```
|
||||
|
||||
### pruned_transducer_stateless7 (Fine-tune with mux)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/1059> for more details.
|
||||
@ -616,7 +665,6 @@ for m in greedy_search modified_beam_search fast_beam_search; do
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
#### Smaller model
|
||||
|
||||
We also provide a very small version (only 6.1M parameters) of this setup. The training command for the small model is:
|
||||
@ -663,6 +711,7 @@ This small model achieves the following WERs on GigaSpeech test and dev sets:
|
||||
|
||||
You can find the tensorboard logs at <https://tensorboard.dev/experiment/tAc5iXxTQrCQxky5O5OLyw/#scalars>.
|
||||
|
||||
|
||||
### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
|
||||
|
||||
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
|
||||
|
0
egs/librispeech/ASR/zipformer_ctc/__init__.py
Normal file
0
egs/librispeech/ASR/zipformer_ctc/__init__.py
Normal file
1
egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../tdnn_lstm_ctc/asr_datamodule.py
|
886
egs/librispeech/ASR/zipformer_ctc/decode.py
Executable file
886
egs/librispeech/ASR/zipformer_ctc/decode.py
Executable file
@ -0,0 +1,886 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_ctc_model, get_params
|
||||
from transformer import encoder_padding_mask
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_rnn_lm,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
load_averaged_model,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=77,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=55,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="attention-decoder",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
- (5) attention-decoder. Extract n paths from the LM rescored
|
||||
lattice, the path with the highest score is the decoding result.
|
||||
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
|
||||
you have trained an RNN LM using ./rnn_lm/train.py
|
||||
- (7) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer_ctc/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-epoch",
|
||||
type=int,
|
||||
default=7,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the checkpoint to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-avg",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the number of checkpoints to average.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
rnn_lm_model: Optional[nn.Module],
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||
rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
rnn_lm_model:
|
||||
The neural model for RNN LM.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
The token ID of the EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict. Note: If it decodes to nothing, then return None.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
nnet_output, _ = model.encoder(feature, feature_lens)
|
||||
ctc_output = model.ctc_output(nnet_output)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // params.subsampling_factor,
|
||||
supervisions["num_frames"] // params.subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a lit-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is only slightly worse than that of rescored lattices.
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
nbest_scale=params.nbest_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.method in ["1best", "nbest"]:
|
||||
if params.method == "1best":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
|
||||
nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
mask = encoder_padding_mask(nnet_output.size(0), supervisions)
|
||||
mask = mask.to(nnet_output.device) if mask is not None else None
|
||||
mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder
|
||||
|
||||
if params.method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.method == "attention-decoder":
|
||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=mmodel,
|
||||
memory=nnet_output,
|
||||
memory_key_padding_mask=mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.method == "rnn-lm":
|
||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
|
||||
best_path_dict = rescore_with_rnn_lm(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
model=mmodel,
|
||||
memory=nnet_output,
|
||||
memory_key_padding_mask=mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
blank_id=0,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
|
||||
ans = dict()
|
||||
if best_path_dict is not None:
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
else:
|
||||
ans = None
|
||||
return ans
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
rnn_lm_model: Optional[nn.Module],
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
rnn_lm_model:
|
||||
The neural model for RNN LM.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
if hyps_dict is not None:
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
else:
|
||||
assert len(results) > 0, "It should not decode to empty in the first batch!"
|
||||
this_batch = []
|
||||
hyp_words = []
|
||||
for ref_text in texts:
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
|
||||
for lm_scale in results.keys():
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]],
|
||||
):
|
||||
if params.method in ("attention-decoder", "rnn-lm"):
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
args.lm_dir = Path(args.lm_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
params.vocab_size = num_classes
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
sos_id = graph_compiler.sos_id
|
||||
eos_id = graph_compiler.eos_id
|
||||
|
||||
params.num_classes = num_classes
|
||||
params.sos_id = sos_id
|
||||
params.eos_id = eos_id
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del G.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set G.properties to None
|
||||
G.__dict__["_properties"] = None
|
||||
G = k2.Fsa.from_fsas([G]).to(device)
|
||||
G = k2.arc_sort(G)
|
||||
# Save a dummy value so that it can be loaded in C++.
|
||||
# See https://github.com/pytorch/pytorch/issues/67902
|
||||
# for why we need to do this.
|
||||
G.dummy = 1
|
||||
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in [
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
else:
|
||||
G = None
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_ctc_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
rnn_lm_model = None
|
||||
if params.method == "rnn-lm":
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.num_classes,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
if params.rnn_lm_avg == 1:
|
||||
load_checkpoint(
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
rnn_lm_model.to(device)
|
||||
else:
|
||||
rnn_lm_model = load_averaged_model(
|
||||
params.rnn_lm_exp_dir,
|
||||
rnn_lm_model,
|
||||
params.rnn_lm_epoch,
|
||||
params.rnn_lm_avg,
|
||||
device,
|
||||
)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
298
egs/librispeech/ASR/zipformer_ctc/decoder.py
Normal file
298
egs/librispeech/ASR/zipformer_ctc/decoder.py
Normal file
@ -0,0 +1,298 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformer import PositionalEncoding, TransformerDecoderLayer
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class implements Transformer based decoder for an attention-based encoder-decoder
|
||||
model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_classes: int,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_layers:
|
||||
Number of layers.
|
||||
num_classes:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
d_model:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if num_layers > 0:
|
||||
self.decoder_num_class = num_classes # bpe model already has sos/eos symbol
|
||||
|
||||
self.decoder_embed = nn.Embedding(
|
||||
num_embeddings=self.decoder_num_class, embedding_dim=d_model
|
||||
)
|
||||
self.decoder_pos = PositionalEncoding(d_model, dropout)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
if normalize_before:
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
else:
|
||||
decoder_norm = None
|
||||
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=num_layers,
|
||||
norm=decoder_norm,
|
||||
)
|
||||
|
||||
self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
|
||||
self.decoder_criterion = LabelSmoothingLoss()
|
||||
else:
|
||||
self.decoder_criterion = None
|
||||
|
||||
@torch.jit.export
|
||||
def forward(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
memory:
|
||||
It's the output of the encoder with shape (T, N, C)
|
||||
memory_key_padding_mask:
|
||||
The padding mask from the encoder.
|
||||
token_ids:
|
||||
A list-of-list IDs. Each sublist contains IDs for an utterance.
|
||||
The IDs can be either phone IDs or word piece IDs.
|
||||
sos_id:
|
||||
sos token id
|
||||
eos_id:
|
||||
eos token id
|
||||
Returns:
|
||||
A scalar, the **sum** of label smoothing loss over utterances
|
||||
in the batch without any normalization.
|
||||
"""
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device)
|
||||
ys_out_pad = ys_out_pad.to(device)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
# We set the first column to False since the first column in ys_in_pad
|
||||
# contains sos_id, which is the same as eos_id in our current setting.
|
||||
tgt_key_padding_mask[:, 0] = False
|
||||
|
||||
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
|
||||
tgt = self.decoder_pos(tgt)
|
||||
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
pred_pad = self.decoder(
|
||||
tgt=tgt,
|
||||
memory=memory,
|
||||
tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
) # (T, N, C)
|
||||
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
|
||||
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
|
||||
|
||||
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
|
||||
|
||||
return decoder_loss
|
||||
|
||||
@torch.jit.export
|
||||
def decoder_nll(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
token_ids: List[torch.Tensor],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
memory:
|
||||
It's the output of the encoder with shape (T, N, C)
|
||||
memory_key_padding_mask:
|
||||
The padding mask from the encoder.
|
||||
token_ids:
|
||||
A list-of-list IDs (e.g., word piece IDs).
|
||||
Each sublist represents an utterance.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
Returns:
|
||||
A 2-D tensor of shape (len(token_ids), max_token_length)
|
||||
representing the cross entropy loss (i.e., negative log-likelihood).
|
||||
"""
|
||||
# The common part between this function and decoder_forward could be
|
||||
# extracted as a separate function.
|
||||
if isinstance(token_ids[0], torch.Tensor):
|
||||
# This branch is executed by torchscript in C++.
|
||||
# See https://github.com/k2-fsa/k2/pull/870
|
||||
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
|
||||
token_ids = [tolist(t) for t in token_ids]
|
||||
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
# We set the first column to False since the first column in ys_in_pad
|
||||
# contains sos_id, which is the same as eos_id in our current setting.
|
||||
tgt_key_padding_mask[:, 0] = False
|
||||
|
||||
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
|
||||
tgt = self.decoder_pos(tgt)
|
||||
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
pred_pad = self.decoder(
|
||||
tgt=tgt,
|
||||
memory=memory,
|
||||
tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
) # (T, B, F)
|
||||
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
|
||||
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
|
||||
# nll: negative log-likelihood
|
||||
nll = torch.nn.functional.cross_entropy(
|
||||
pred_pad.view(-1, self.decoder_num_class),
|
||||
ys_out_pad.view(-1),
|
||||
ignore_index=-1,
|
||||
reduction="none",
|
||||
)
|
||||
|
||||
nll = nll.view(pred_pad.shape[0], -1)
|
||||
|
||||
return nll
|
||||
|
||||
|
||||
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
|
||||
"""Prepend sos_id to each utterance.
|
||||
Args:
|
||||
token_ids:
|
||||
A list-of-list of token IDs. Each sublist contains
|
||||
token IDs (e.g., word piece IDs) of an utterance.
|
||||
sos_id:
|
||||
The ID of the SOS token.
|
||||
Return:
|
||||
Return a new list-of-list, where each sublist starts
|
||||
with SOS ID.
|
||||
"""
|
||||
return [[sos_id] + utt for utt in token_ids]
|
||||
|
||||
|
||||
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
||||
"""Append eos_id to each utterance.
|
||||
Args:
|
||||
token_ids:
|
||||
A list-of-list of token IDs. Each sublist contains
|
||||
token IDs (e.g., word piece IDs) of an utterance.
|
||||
eos_id:
|
||||
The ID of the EOS token.
|
||||
Return:
|
||||
Return a new list-of-list, where each sublist ends
|
||||
with EOS ID.
|
||||
"""
|
||||
return [utt + [eos_id] for utt in token_ids]
|
||||
|
||||
|
||||
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
|
||||
"""Generate a length mask for input.
|
||||
The masked position are filled with True,
|
||||
Unmasked positions are filled with False.
|
||||
Args:
|
||||
ys_pad:
|
||||
padded tensor of dimension (batch_size, input_length).
|
||||
ignore_id:
|
||||
the ignored number (the padding number) in ys_pad
|
||||
Returns:
|
||||
Tensor:
|
||||
a bool tensor of the same shape as the input tensor.
|
||||
"""
|
||||
ys_mask = ys_pad == ignore_id
|
||||
return ys_mask
|
||||
|
||||
|
||||
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
||||
"""Generate a square mask for the sequence. The masked positions are
|
||||
filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
||||
The mask can be used for masked self-attention.
|
||||
For instance, if sz is 3, it returns::
|
||||
tensor([[0., -inf, -inf],
|
||||
[0., 0., -inf],
|
||||
[0., 0., 0]])
|
||||
Args:
|
||||
sz: mask size
|
||||
Returns:
|
||||
A square mask of dimension (sz, sz)
|
||||
"""
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = (
|
||||
mask.float()
|
||||
.masked_fill(mask == 0, float("-inf"))
|
||||
.masked_fill(mask == 1, float(0.0))
|
||||
)
|
||||
return mask
|
||||
|
||||
|
||||
def tolist(t: torch.Tensor) -> List[int]:
|
||||
"""Used by jit"""
|
||||
return torch.jit.annotate(List[int], t.tolist())
|
1
egs/librispeech/ASR/zipformer_ctc/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/encoder_interface.py
|
240
egs/librispeech/ASR/zipformer_ctc/export.py
Executable file
240
egs/librispeech/ASR/zipformer_ctc/export.py
Executable file
@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_ctc_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer_ctc/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="""It contains language related input files such as "lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
params.vocab_size = num_classes
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = get_ctc_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/zipformer_ctc/label_smoothing.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conformer_ctc/label_smoothing.py
|
158
egs/librispeech/ASR/zipformer_ctc/model.py
Normal file
158
egs/librispeech/ASR/zipformer_ctc/model.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from transformer import encoder_padding_mask
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.utils import encode_supervisions
|
||||
|
||||
|
||||
class CTCModel(nn.Module):
|
||||
"""It implements a CTC model with an auxiliary attention head."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
encoder_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
An instance of `EncoderInterface`. The shared encoder for the CTC and attention
|
||||
branches
|
||||
decoder:
|
||||
An instance of `nn.Module`. This is the decoder for the attention branch.
|
||||
encoder_dim:
|
||||
Dimension of the encoder output.
|
||||
decoder_dim:
|
||||
Dimension of the decoder output.
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
self.encoder = encoder
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
self.decoder = decoder
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
supervisions: torch.Tensor,
|
||||
graph_compiler: BpeCtcTrainingGraphCompiler,
|
||||
subsampling_factor: int = 1,
|
||||
beam_size: int = 10,
|
||||
reduction: str = "sum",
|
||||
use_double_scores: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
Tensor of dimension (N, T, C) where N is the batch size,
|
||||
T is the number of frames, and C is the feature dimension.
|
||||
x_lens:
|
||||
Tensor of dimension (N,) where N is the batch size.
|
||||
supervisions:
|
||||
Supervisions are used in training.
|
||||
graph_compiler:
|
||||
It is used to compile a decoding graph from texts.
|
||||
subsampling_factor:
|
||||
It is used to compute the `supervisions` for the encoder.
|
||||
beam_size:
|
||||
Beam size used in `k2.ctc_loss`.
|
||||
reduction:
|
||||
Reduction method used in `k2.ctc_loss`.
|
||||
use_double_scores:
|
||||
If True, use double precision in `k2.ctc_loss`.
|
||||
Returns:
|
||||
Return the CTC loss, attention loss, and the total number of frames.
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
|
||||
nnet_output, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
# compute ctc log-probs
|
||||
ctc_output = self.ctc_output(nnet_output)
|
||||
|
||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||
# different duration in decreasing order, required by
|
||||
# `k2.intersect_dense` called in `k2.ctc_loss`
|
||||
supervision_segments, texts = encode_supervisions(
|
||||
supervisions, subsampling_factor=subsampling_factor
|
||||
)
|
||||
num_frames = supervision_segments[:, 2].sum().item()
|
||||
|
||||
# Works with a BPE model
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
decoding_graph = graph_compiler.compile(token_ids)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
ctc_output,
|
||||
supervision_segments.cpu(),
|
||||
allow_truncate=subsampling_factor - 1,
|
||||
)
|
||||
|
||||
ctc_loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
output_beam=beam_size,
|
||||
reduction=reduction,
|
||||
use_double_scores=use_double_scores,
|
||||
)
|
||||
|
||||
if self.decoder is not None:
|
||||
nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
mmodel = (
|
||||
self.decoder.module if hasattr(self.decoder, "module") else self.decoder
|
||||
)
|
||||
# Note: We need to generate an unsorted version of token_ids
|
||||
# `encode_supervisions()` called above sorts text, but
|
||||
# encoder_memory and memory_mask are not sorted, so we
|
||||
# use an unsorted version `supervisions["text"]` to regenerate
|
||||
# the token_ids
|
||||
#
|
||||
# See https://github.com/k2-fsa/icefall/issues/97
|
||||
# for more details
|
||||
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
|
||||
mask = encoder_padding_mask(nnet_output.size(0), supervisions)
|
||||
mask = mask.to(nnet_output.device) if mask is not None else None
|
||||
att_loss = mmodel.forward(
|
||||
nnet_output,
|
||||
mask,
|
||||
token_ids=unsorted_token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = torch.tensor([0])
|
||||
|
||||
return ctc_loss, att_loss, num_frames
|
1
egs/librispeech/ASR/zipformer_ctc/optim.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/optim.py
|
1
egs/librispeech/ASR/zipformer_ctc/scaling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling.py
|
1
egs/librispeech/ASR/zipformer_ctc/scaling_converter.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling_converter.py
|
1
egs/librispeech/ASR/zipformer_ctc/subsampling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conformer_ctc/subsampling.py
|
1135
egs/librispeech/ASR/zipformer_ctc/train.py
Executable file
1135
egs/librispeech/ASR/zipformer_ctc/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_ctc/transformer.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/transformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conformer_ctc/transformer.py
|
1
egs/librispeech/ASR/zipformer_ctc/zipformer.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/zipformer.py
|
Loading…
x
Reference in New Issue
Block a user