RNNLM rescore + Low-order density ratio (#1017)

* add rnnlm rescore + LODR

* add LODR in decode.py

* update RESULTS
marcoyang1998 2023-04-24 15:00:02 +08:00 committed by GitHub
parent 2096e69bda
commit 45c13e90e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 345 additions and 10 deletions

View File

@@ -215,11 +215,12 @@ done
We also support decoding with neural network LMs. After combining with language models, the WERs are:

| decoding method | chunk size | test-clean | test-other | comment | decoding mode |
|----------------------|------------|------------|------------|---------------------|----------------------|
| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming |

Please use the following command for `modified_beam_search_lm_shallow_fusion`:
```bash
for lm_scale in $(seq 0.15 0.01 0.38); do
  for beam_size in 4 8 12; do
@@ -246,7 +247,7 @@ for lm_scale in $(seq 0.15 0.01 0.38); do
done
```

Please use the following command for `modified_beam_search_lm_rescore`:
```bash
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 30 \
@@ -268,7 +269,32 @@ Please use the following command for RNNLM rescore:
--lm-vocab-size 500
```

Please use the following command for `modified_beam_search_lm_rescore_LODR`:
```bash
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 30 \
--avg 9 \
--use-averaged-model True \
--beam-size 8 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_rescore_LODR \
--use-shallow-fusion 0 \
--lm-type rnn \
--lm-exp-dir rnn_lm/exp \
--lm-epoch 99 \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500 \
--tokens-ngram 2 \
--backoff-id 500
```
A well-trained RNNLM can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>. The bi-gram used in LODR decoding
can be found here: <https://huggingface.co/marcoyang/librispeech_bigram>.
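
For reference, `modified_beam_search_lm_rescore_LODR` combines three scores for every n-best hypothesis: the transducer (AM) score, the RNNLM sentence score, and a subtracted low-order n-gram score that approximates the label prior learnt internally by the transducer. The decoding script grid-searches `lm_scale` over 0.04–0.58 and `lodr_scale` over 0.05–0.95. Below is a minimal sketch of the score combination; the scores and scale values are illustrative only, not recipe defaults:

```python
import itertools

# Hypothetical log-scores (natural log) for a single hypothesis.
am_score = -12.3     # transducer (AM) score
nnlm_score = -8.1    # RNNLM sentence log-probability
bigram_score = -9.7  # token-level bi-gram log-probability

for lm_scale, lodr_scale in itertools.product([0.4, 0.5], [0.1, 0.2]):
    # Same combination as modified_beam_search_lm_rescore_LODR:
    # subtracting the scaled bi-gram removes the implicit label prior.
    tot = am_score / lm_scale + nnlm_score - lodr_scale * bigram_score
    print(f"lm_scale={lm_scale:.2f} lodr_scale={lodr_scale:.2f} tot={tot:.2f}")
```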
#### Smaller model

View File

@@ -1244,7 +1244,7 @@ def modified_beam_search_lm_rescore(
    # get the best hyp with different lm_scale
    for lm_scale in lm_scale_list:
        key = f"nnlm_scale_{lm_scale:.2f}"
        tot_scores = am_scores.values + lm_scores * lm_scale
        ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
        max_indexes = ragged_tot_scores.argmax().tolist()
@@ -1257,6 +1257,222 @@ def modified_beam_search_lm_rescore(
    return ans

def modified_beam_search_lm_rescore_LODR(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    LM: LmScorer,
    LODR_lm: NgramLm,
    sp: spm.SentencePieceProcessor,
    lm_scale_list: List[float],
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,
) -> Dict[str, List[List[int]]]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
    Rescore the final results with an RNNLM and a low-order n-gram LM (LODR),
    and return, for each scale pair, the hypotheses with the highest combined
    score.
    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding.
      beam:
        Number of active paths during the beam search.
      temperature:
        Softmax temperature.
      LM:
        A neural network language model.
      LODR_lm:
        A low-order (e.g. bi-gram) token-level n-gram LM; in this recipe it
        is a kenlm model loaded in decode.py.
      sp:
        The BPE SentencePiece processor, used to map token IDs to pieces
        for n-gram scoring.
      lm_scale_list:
        Neural LM scales to grid-search during rescoring.
      return_timestamps:
        Whether to return timestamps (not used by this function at present).
    Returns:
      A dict mapping a scale-pair key, e.g. "nnlm_scale_0.50_lodr_scale_0.30",
      to the decoded token IDs of each utterance.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
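    # batch_size_list[t] is the number of utterances that still have a valid
    # frame at time step t (pack_padded_sequence sorts streams by length).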
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end
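        # Utterances whose frames are exhausted at this step are finished;
        # park their hypothesis lists in finalized_B.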
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)
        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
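        # ragged_log_probs groups the flattened (num_hyps x vocab_size)
        # scores by utterance, so the topk below searches jointly over all
        # of an utterance's hypotheses and their next-token candidates.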
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(
                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
                )
                B[i].add(new_hyp)

    B = B + finalized_B

    # get the am_scores for the n-best list
    hyps_shape = get_hyps_shape(B)
    am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b])
    am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device)

    # now LM rescore
    # prepare input data to LM
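    # hyp.ys starts with context_size blanks; strip them to recover each
    # hypothesis' actual token sequence.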
    candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b]
    possible_seqs = k2.RaggedTensor(candidate_seqs)
    row_splits = possible_seqs.shape.row_splits(1)
    sentence_token_lengths = row_splits[1:] - row_splits[:-1]
    possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1)
    possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1)
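    # Each sequence gains one token (SOS on the input side, EOS on the
    # target side), so the lengths grow by one.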
    sentence_token_lengths += 1

    x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id)
    y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id)
    x = x.to(device).to(torch.int64)
    y = y.to(device).to(torch.int64)
    sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)

    lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
    assert lm_scores.ndim == 2
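    # LM.lm is expected to return per-token negative log-likelihoods of
    # shape (num_hyps, T); summing over time and negating yields sentence
    # log-probabilities.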
    lm_scores = -1 * lm_scores.sum(dim=1)

    # now LODR scores
    import math

    LODR_scores = []
    for seq in candidate_seqs:
        tokens = " ".join(sp.id_to_piece(seq))
        LODR_scores.append(LODR_lm.score(tokens))
    LODR_scores = torch.tensor(LODR_scores).to(device) * math.log(
        10
    )  # ARPA scores are log10-based; convert to natural log
    assert lm_scores.shape == LODR_scores.shape

    ans = {}
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()

    LODR_scale_list = [0.05 * i for i in range(1, 20)]
    # get the best hyp for each (lm_scale, lodr_scale) pair
    for lm_scale in lm_scale_list:
        for lodr_scale in LODR_scale_list:
            key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}"
            tot_scores = (
                am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale
            )
            ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
            max_indexes = ragged_tot_scores.argmax().tolist()
            unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes]
            hyps = []
            for idx in unsorted_indices:
                hyps.append(unsorted_hyps[idx])
            ans[key] = hyps
    return ans

def _deprecated_modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,

View File

@@ -123,10 +123,13 @@ from beam_search import (
    greedy_search_batch,
    modified_beam_search,
    modified_beam_search_lm_rescore,
    modified_beam_search_lm_rescore_LODR,
    modified_beam_search_lm_shallow_fusion,
    modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -134,7 +137,6 @@ from icefall.checkpoint import (
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.lm_wrapper import LmScorer
from icefall.utils import (
    AttributeDict,
    setup_logger,
@@ -336,6 +338,21 @@ def get_parser():
        """,
    )
    parser.add_argument(
        "--tokens-ngram",
        type=int,
        default=2,
        help="""The order of the n-gram LM used for LODR.
        """,
    )

    parser.add_argument(
        "--backoff-id",
        type=int,
        default=500,
        help="ID of the backoff symbol in the n-gram LM",
    )

    add_model_arguments(parser)

    return parser
@@ -349,6 +366,8 @@ def decode_one_batch(
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
    ngram_lm=None,
    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -483,6 +502,18 @@ def decode_one_batch(
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_lm_rescore":
        lm_scale_list = [0.01 * i for i in range(10, 50)]
        ans_dict = modified_beam_search_lm_rescore(
@@ -493,6 +524,18 @@ def decode_one_batch(
            LM=LM,
            lm_scale_list=lm_scale_list,
        )
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.02 * i for i in range(2, 30)]
ans_dict = modified_beam_search_lm_rescore_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
)
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@@ -531,7 +574,10 @@ def decode_one_batch(
        key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
        return {key: hyps}
    elif params.decoding_method in (
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
    ):
        ans = dict()
        assert ans_dict is not None
        for key, hyps in ans_dict.items():
@@ -550,6 +596,8 @@ def decode_dataset(
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
    ngram_lm=None,
    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.
@@ -568,6 +616,8 @@ def decode_dataset(
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      ngram_lm:
        An n-gram LM to be used for LODR decoding.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
@@ -600,6 +650,8 @@ def decode_dataset(
            word_table=word_table,
            batch=batch,
            LM=LM,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():
@@ -677,8 +729,10 @@ def main():
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
        "modified_beam_search_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
    )
    params.res_dir = params.exp_dir / params.decoding_method
@@ -822,7 +876,12 @@ def main():
    model.eval()

    # only load the neural network LM if required
    if params.use_shallow_fusion or params.decoding_method in (
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_LODR",
    ):
        LM = LmScorer(
            lm_type=params.lm_type,
            params=params,
@@ -834,6 +893,35 @@ def main():
    else:
        LM = None
    # only load N-gram LM when needed
    if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
        try:
            import kenlm
        except ImportError:
            print("Please install kenlm first. You can use")
            print("  pip install https://github.com/kpu/kenlm/archive/master.zip")
            print("to install it")
            import sys

            sys.exit(-1)
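        # A token-level (BPE) ARPA LM is expected, e.g. lang_dir/2gram.arpa.
        # Note that kenlm returns log10 scores; beam_search.py converts them
        # to natural log before combining with the other scores.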
        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
        logging.info(f"lm filename: {ngram_file_name}")
        ngram_lm = kenlm.Model(ngram_file_name)
    elif params.decoding_method == "modified_beam_search_LODR":
        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
        logging.info(f"Loading token level lm: {lm_filename}")
        ngram_lm = NgramLm(
            str(params.lang_dir / lm_filename),
            backoff_id=params.backoff_id,
            is_binary=False,
        )
        logging.info(f"num states: {ngram_lm.lm.num_states}")
        ngram_lm_scale = params.ngram_lm_scale
    else:
        ngram_lm = None
        ngram_lm_scale = None
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
@@ -866,8 +954,10 @@ def main():
    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

    import time

    for test_set, test_dl in zip(test_sets, test_dl):
        start = time.time()
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
@@ -876,7 +966,10 @@ def main():
            word_table=word_table,
            decoding_graph=decoding_graph,
            LM=LM,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )
        logging.info(f"Elapsed time for {test_set}: {time.time() - start}")

        save_results(
            params=params,