From 88067f7566d2ddda44bbc3a45627bc5dbceb7f54 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 8 Aug 2023 11:28:21 +0800
Subject: [PATCH] Add context biasing for zipformer recipe

---
 egs/librispeech/ASR/zipformer/decode.py | 80 +++++++++++++++++++++----
 1 file changed, 69 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index 2cc157e7a..1da5b2669 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -97,6 +97,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -122,7 +123,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_model, get_params
 
-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -347,6 +348,25 @@ def get_parser():
         help="ID of the backoff symbol in the ngram LM",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding-method is modified_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding-method is modified_beam_search.
+        """,
+    )
     add_model_arguments(parser)
 
     return parser
@@ -359,6 +379,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
@@ -493,6 +514,7 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -503,6 +525,7 @@ encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -515,6 +538,7 @@
             LODR_lm=ngram_lm,
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -527,6 +551,7 @@
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
+            context_graph=context_graph,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -539,6 +564,7 @@
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
+            context_graph=context_graph,
         )
     else:
         batch_size = encoder_out.size(0)
@@ -578,16 +604,24 @@
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
         return {key: hyps}
-    elif params.decoding_method in (
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
-    ):
-        ans = dict()
-        assert ans_dict is not None
-        for key, hyps in ans_dict.items():
-            hyps = [sp.decode(hyp).split() for hyp in hyps]
-            ans[f"beam_size_{params.beam_size}_{key}"] = hyps
-        return ans
+    elif "modified_beam_search" in params.decoding_method:
+        prefix = f"beam_size_{params.beam_size}"
+        if params.has_contexts:
+            prefix += f"-context-score-{params.context_score}"
+        else:
+            prefix += "-no-context-words"
+        if params.decoding_method in (
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
+            ans = dict()
+            assert ans_dict is not None
+            for key, hyps in ans_dict.items():
+                hyps = [sp.decode(hyp).split() for hyp in hyps]
+                ans[f"{prefix}_{key}"] = hyps
+            return ans
+        else:
+            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
 
@@ -599,6 +633,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
@@ -649,6 +684,7 @@
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
             LM=LM,
@@ -744,6 +780,11 @@
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
@@ -770,6 +811,10 @@
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.has_contexts:
+            params.suffix += f"-context-score-{params.context_score}"
+        else:
+            params.suffix += "-no-context-words"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -952,6 +997,18 @@
         decoding_graph = None
         word_table = None
 
+    if "modified_beam_search" in params.decoding_method:
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -976,6 +1033,7 @@
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             LM=LM,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
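
Reviewer note: the snippet below is a minimal, self-contained sketch of the
graph construction this patch adds to main(), so the biasing setup can be
tried outside the full recipe. It relies only on the ContextGraph API
exercised above (ContextGraph(context_score) and build(token_ids)); the BPE
model path and the contexts file name are hypothetical placeholders, not part
of the patch.

    import sentencepiece as spm

    from icefall import ContextGraph

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # placeholder BPE model path

    # Mirrors the --context-file format: one biasing word/phrase per line.
    with open("contexts.txt") as f:  # placeholder for --context-file
        contexts = [line.strip() for line in f if line.strip()]

    # During modified_beam_search, each matched token of a biasing
    # word/phrase receives a bonus of context_score (--context-score,
    # default 2) on top of its regular score.
    context_graph = ContextGraph(2.0)
    context_graph.build(sp.encode(contexts))

The resulting context_graph is then threaded through decode_dataset() and
decode_one_batch() into the modified_beam_search* calls, exactly as the hunks
above do with params.context_file and params.context_score.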