Merge branch 'k2-fsa:master' into dev_swbd

commit c78aabf3b0
zr_jin, 2023-09-13 09:39:45 +08:00 (committed via GitHub)
28 changed files with 1737 additions and 67 deletions

View File

@@ -29,6 +29,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
   ls -lh data/fbank
   ls -lh pruned_transducer_stateless2/exp
+  ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz
+  ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz
   log "Decoding dev and test"
   # use a small value for decoding with CPU

View File

@@ -45,7 +45,7 @@ jobs:
       strategy:
         matrix:
           os: [ubuntu-latest]
-          python-version: [3.7, 3.8, 3.9]
+          python-version: [3.8]
         fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -44,7 +44,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -34,7 +34,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -43,7 +43,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -43,7 +43,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -34,7 +34,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -34,7 +34,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -43,7 +43,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -34,7 +34,7 @@ jobs:
      strategy:
        matrix:
          os: [ubuntu-latest]
-         python-version: [3.7, 3.8, 3.9]
+         python-version: [3.8]
        fail-fast: false

View File

@@ -71,9 +71,12 @@ As the initial step, let's download the pre-trained model.

 .. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
-   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
+   $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
+   $ cd ../data/lang_bpe_500
+   $ git lfs pull --include bpe.model
+   $ cd ../../..

 To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:

View File

@@ -34,9 +34,12 @@ As the initial step, let's download the pre-trained model.

 .. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
-   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
+   $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
+   $ cd ../data/lang_bpe_500
+   $ git lfs pull --include bpe.model
+   $ cd ../../..

 As usual, we first test the model's performance without external LM. This can be done via the following command:

View File

@@ -32,9 +32,12 @@ As the initial step, let's download the pre-trained model.

 .. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
-   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
+   $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
+   $ cd ../data/lang_bpe_500
+   $ git lfs pull --include bpe.model
+   $ cd ../../..

 To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:

View File

@@ -1 +0,0 @@
-../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

File diff suppressed because it is too large

View File

@@ -2389,6 +2389,7 @@ def modified_beam_search_LODR(
     LODR_lm_scale: float,
     LM: LmScorer,
     beam: int = 4,
+    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """This function implements LODR (https://arxiv.org/abs/2203.16776) with
     `modified_beam_search`. It uses a bi-gram language model as the estimate

@@ -2457,6 +2458,7 @@ def modified_beam_search_LODR(
                 state_cost=NgramLmStateCost(
                     LODR_lm
                 ),  # state of the source domain ngram
+                context_state=None if context_graph is None else context_graph.root,
             )
         )

@@ -2602,8 +2604,17 @@ def modified_beam_search_LODR(
             hyp_log_prob = topk_log_probs[k]  # get score of current hyp
             new_token = topk_token_indexes[k]
+            context_score = 0
+            new_context_state = None if context_graph is None else hyp.context_state
             if new_token not in (blank_id, unk_id):
+                if context_graph is not None:
+                    (
+                        context_score,
+                        new_context_state,
+                    ) = context_graph.forward_one_step(hyp.context_state, new_token)
                 ys.append(new_token)
                 state_cost = hyp.state_cost.forward_one_step(new_token)

@@ -2619,6 +2630,7 @@ def modified_beam_search_LODR(
             hyp_log_prob += (
                 lm_score[new_token] * lm_scale
                 + LODR_lm_scale * current_ngram_score
+                + context_score
             )  # add the lm score

             lm_score = scores[count]

@@ -2637,10 +2649,31 @@ def modified_beam_search_LODR(
                 state=state,
                 lm_score=lm_score,
                 state_cost=state_cost,
+                context_state=new_context_state,
             )
             B[i].add(new_hyp)

     B = B + finalized_B

+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
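Taken together, the changes above thread an optional ContextGraph through modified_beam_search_LODR: each hypothesis now carries a context_state, every non-blank token advances it via forward_one_step (collecting a bonus when it extends a biasing phrase), and after the last frame finalize cancels the bonus of matches that never reached a final state. A minimal sketch of how the new argument might be called, assuming the surrounding model, encoder_out, encoder_out_lens, ngram_lm, and lm_scorer objects from the recipe; the LODR scale is illustrative:

    from icefall import ContextGraph

    # The per-token bonus is the constructor argument (see the ContextGraph
    # changes later in this commit); phrases are BPE-encoded before insertion.
    context_graph = ContextGraph(2.0)
    context_graph.build(sp.encode(["ZENGWEI", "ICEFALL"]))

    hyp_tokens = modified_beam_search_LODR(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        LODR_lm=ngram_lm,             # source-domain bi-gram for LODR
        LODR_lm_scale=-0.16,          # illustrative; LODR subtracts this estimate
        LM=lm_scorer,                 # neural LM for shallow fusion
        beam=4,
        context_graph=context_graph,  # new in this commit; None disables biasing
    )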

View File

@@ -26,7 +26,7 @@ You can generate the checkpoint with the following command:
 ./pruned_transducer_stateless7/export.py \
   --exp-dir ./pruned_transducer_stateless7/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9

@@ -52,12 +52,12 @@ import torch
 import torch.nn as nn
 from alignment import batch_force_alignment
 from asr_datamodule import LibriSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_transducer_model
-from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
 from lhotse import CutSet
 from lhotse.serialization import SequentialJsonlWriter
 from lhotse.supervision import AlignmentItem
+from train import add_model_arguments, get_params, get_transducer_model
+from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp

 def get_parser():
View File

@@ -97,6 +97,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -122,7 +123,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_model, get_params

-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -215,6 +216,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - modified_beam_search_LODR
           - fast_beam_search
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle

@@ -251,7 +253,7 @@ def get_parser():
         type=float,
         default=0.01,
         help="""
-        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        Used only when --decoding-method is fast_beam_search_nbest_LG.
         It specifies the scale for n-gram LM scores.
         """,
     )

@@ -285,7 +287,7 @@ def get_parser():
         type=int,
         default=1,
         help="""Maximum number of symbols per frame.
-        Used only when --decoding_method is greedy_search""",
+        Used only when --decoding-method is greedy_search""",
     )

     parser.add_argument(

@@ -347,6 +349,27 @@ def get_parser():
         help="ID of the backoff symbol in the ngram LM",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding-method is modified_beam_search and
+        modified_beam_search_LODR.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding-method is modified_beam_search and
+        modified_beam_search_LODR.
+        """,
+    )
     add_model_arguments(parser)

     return parser

@@ -359,6 +382,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,

@@ -388,7 +412,7 @@ def decode_one_batch(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
       LM:
         A neural network language model.

@@ -493,6 +517,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())

@@ -515,6 +540,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())

@@ -578,16 +604,22 @@ def decode_one_batch(
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

         return {key: hyps}
-    elif params.decoding_method in (
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
-    ):
-        ans = dict()
-        assert ans_dict is not None
-        for key, hyps in ans_dict.items():
-            hyps = [sp.decode(hyp).split() for hyp in hyps]
-            ans[f"beam_size_{params.beam_size}_{key}"] = hyps
-        return ans
+    elif "modified_beam_search" in params.decoding_method:
+        prefix = f"beam_size_{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
+            ans = dict()
+            assert ans_dict is not None
+            for key, hyps in ans_dict.items():
+                hyps = [sp.decode(hyp).split() for hyp in hyps]
+                ans[f"{prefix}_{key}"] = hyps
+            return ans
+        else:
+            if params.has_contexts:
+                prefix += f"-context-score-{params.context_score}"
+            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}

@@ -599,6 +631,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,

@@ -618,7 +651,7 @@ def decode_dataset(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search

@@ -649,6 +682,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
             LM=LM,

@@ -744,6 +778,11 @@ def main():
     )

     params.res_dir = params.exp_dir / params.decoding_method

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:

@@ -770,6 +809,12 @@ def main():
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -952,6 +997,18 @@ def main():
         decoding_graph = None
         word_table = None

+    if "modified_beam_search" in params.decoding_method:
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -976,6 +1033,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             LM=LM,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
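In short, decode.py gains two options: --context-file points to a plain-text list of biasing words/phrases (one per line) and --context-score sets the per-token bonus (default 2). When the file exists, both modified_beam_search and modified_beam_search_LODR decode with the resulting graph, and the result files get a -context-score-* suffix. A small sketch of the expected file format and the loading step it triggers, with a hypothetical file name:

    # contexts.txt (hypothetical), one biasing word/phrase per line:
    #   ZENGWEI
    #   PRUNED TRANSDUCER
    with open("contexts.txt") as f:
        contexts = [line.strip() for line in f]

    context_graph = ContextGraph(2.0)         # mirrors the --context-score default
    context_graph.build(sp.encode(contexts))  # each phrase is BPE-encoded first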

View File

@@ -28,6 +28,7 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri
 # even when we are not invoking the main (e.g. when spawning subprocesses).
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")

 def compute_fbank_wenetspeech_dev_test():
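Forcing the "file_system" sharing strategy is a common guard here: with the default "file_descriptor" strategy, every tensor shared between worker processes holds an open file descriptor, and feature-extraction jobs with many workers can exhaust the process's open-file limit. A minimal sketch of the idea, independent of this recipe (the same line is added to both fbank scripts in this commit):

    import torch.multiprocessing as mp

    # Both strategies are typically available on Linux.
    print(mp.get_all_sharing_strategies())  # e.g. {'file_descriptor', 'file_system'}

    # Must run before worker processes are spawned; trades open file
    # descriptors for files backed by shared memory (/dev/shm).
    mp.set_sharing_strategy("file_system")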

View File

@@ -37,6 +37,7 @@ from lhotse import (
 # even when we are not invoking the main (e.g. when spawning subprocesses).
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")

 def get_parser():

View File

@@ -29,7 +29,7 @@ class ContextState:
         token: int,
         token_score: float,
         node_score: float,
-        local_node_score: float,
+        output_score: float,
         is_end: bool,
     ):
         """Create a ContextState.

@@ -40,16 +40,15 @@ class ContextState:
             The id of the root node is always 0.
           token:
             The token id.
-          score:
+          token_score:
             The bonus for each token during decoding, which will hopefully
             boost the token up to survive beam search.
           node_score:
             The accumulated bonus from root of graph to current node, it will be
             used to calculate the score for fail arc.
-          local_node_score:
-            The accumulated bonus from last ``end_node``(node with is_end true)
-            to current_node, it will be used to calculate the score for fail arc.
-            Node: The local_node_score of a ``end_node`` is 0.
+          output_score:
+            The total scores of matched phrases, sum of the node_score of all
+            the output node for current node.
           is_end:
             True if current token is the end of a context.
         """

@@ -57,7 +56,7 @@ class ContextState:
         self.token = token
         self.token_score = token_score
         self.node_score = node_score
-        self.local_node_score = local_node_score
+        self.output_score = output_score
         self.is_end = is_end
         self.next = {}
         self.fail = None

@@ -93,7 +92,7 @@ class ContextGraph:
             token=-1,
             token_score=0,
             node_score=0,
-            local_node_score=0,
+            output_score=0,
             is_end=False,
         )
         self.root.fail = self.root

@@ -131,6 +130,7 @@ class ContextGraph:
                         output = None
                         break
                 node.output = output
+                node.output_score += 0 if output is None else output.output_score
                 queue.append(node)

     def build(self, token_ids: List[List[int]]):

@@ -153,14 +153,13 @@ class ContextGraph:
             if token not in node.next:
                 self.num_nodes += 1
                 is_end = i == len(tokens) - 1
+                node_score = node.node_score + self.context_score
                 node.next[token] = ContextState(
                     id=self.num_nodes,
                     token=token,
                     token_score=self.context_score,
-                    node_score=node.node_score + self.context_score,
-                    local_node_score=0
-                    if is_end
-                    else (node.local_node_score + self.context_score),
+                    node_score=node_score,
+                    output_score=node_score if is_end else 0,
                     is_end=is_end,
                 )
             node = node.next[token]

@@ -186,8 +185,6 @@ class ContextGraph:
         if token in state.next:
             node = state.next[token]
             score = node.token_score
-            if state.is_end:
-                score += state.node_score
         else:
             # token not matched
             # We will trace along the fail arc until it matches the token or reaching

@@ -202,14 +199,9 @@ class ContextGraph:
                 node = node.next[token]
         # The score of the fail path
-        score = node.node_score - state.local_node_score
+        score = node.node_score - state.node_score
         assert node is not None
-        matched_score = 0
-        output = node.output
-        while output is not None:
-            matched_score += output.node_score
-            output = output.output
-        return (score + matched_score, node)
+        return (score + node.output_score, node)

     def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
         """When reaching the end of the decoded sequence, we need to finalize

@@ -227,8 +219,6 @@ class ContextGraph:
         """
         # The score of the fail arc
         score = -state.node_score
-        if state.is_end:
-            score = 0
         return (score, self.root)

     def draw(

@@ -307,10 +297,8 @@ class ContextGraph:
             for token, node in current_node.next.items():
                 if node.id not in seen:
                     node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
-                    local_node_score = f"{node.local_node_score:.2f}".rstrip(
-                        "0"
-                    ).rstrip(".")
-                    label = f"{node.id}/({node_score},{local_node_score})"
+                    output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
+                    label = f"{node.id}/({node_score}, {output_score})"
                     if node.is_end:
                         dot.node(str(node.id), label=label, **final_state_attr)
                     else:

@@ -391,6 +379,7 @@ if __name__ == "__main__":
         "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
         "HISHE": 9,  # "HIS", "S", "SHE", "HE"
         "SHED": 6,  # "S", "SHE", "HE"
+        "SHELF": 6,  # "S", "SHE", "HE"
         "HELL": 2,  # "HE"
         "HELLO": 7,  # "HE", "HELLO"
         "DHRHISQ": 4,  # "HIS", "S"

View File

@@ -493,6 +493,7 @@ def write_error_stats(
     test_set_name: str,
     results: List[Tuple[str, str]],
     enable_log: bool = True,
+    sclite_mode: bool = False,
 ) -> float:
     """Write statistics based on predicted results and reference transcripts.

@@ -538,7 +539,7 @@ def write_error_stats(
     num_corr = 0
     ERR = "*"
     for cut_id, ref, hyp in results:
-        ali = kaldialign.align(ref, hyp, ERR)
+        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:
                 ins[hyp_word] += 1
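write_error_stats simply forwards the new sclite_mode flag to kaldialign.align, which is why the same commit bumps the kaldialign pin from 0.2 to 0.7.1; the keyword is not available in the old pin. A short usage sketch with made-up word lists:

    import kaldialign  # requires kaldialign >= 0.7 for sclite_mode

    EPS = "*"
    ref = "THE CAT SAT ON THE MAT".split()
    hyp = "THE CATS AT ON MAT".split()

    # sclite-style costs can pair up errors differently from the default
    # Levenshtein alignment, matching NIST sclite's scoring conventions.
    print(kaldialign.align(ref, hyp, EPS))
    print(kaldialign.align(ref, hyp, EPS, sclite_mode=True))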

View File

@@ -15,7 +15,7 @@ graphviz==0.19.1
 git+https://github.com/lhotse-speech/lhotse
 kaldilm==1.11
-kaldialign==0.2
+kaldialign==0.7.1
 sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3