First version using conformer lm for rescoring (not tested)

Fangjun Kuang 2021-11-03 20:59:54 +08:00
parent 1ac9bb3fd7
commit cdd539e55c
10 changed files with 331 additions and 796 deletions

View File

@ -21,8 +21,12 @@ import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from conformer_ctc.transformer import (
Supervisions,
Transformer,
encoder_padding_mask,
)
from torch import Tensor, nn from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer): class Conformer(Transformer):

View File

@ -26,8 +26,9 @@ import k2
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer_ctc.conformer import Conformer
from conformer_lm.conformer import MaskedLmConformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -37,6 +38,7 @@ from icefall.decode import (
nbest_oracle, nbest_oracle,
one_best_decoding, one_best_decoding,
rescore_with_attention_decoder, rescore_with_attention_decoder,
rescore_with_conformer_lm,
rescore_with_n_best_list, rescore_with_n_best_list,
rescore_with_whole_lattice, rescore_with_whole_lattice,
) )
@ -94,7 +96,10 @@ def get_parser():
is the decoding result. is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored - (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result. lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best - (6) conformer-lm. In addition to attention-decoder rescoring, it
also uses a conformer LM for rescoring. See the model in the
directory ./conformer_lm
- (7) nbest-oracle. Its WER is the lower bound that any n-best
rescoring method can achieve. Useful for debugging n-best rescoring method can achieve. Useful for debugging n-best
rescoring method. rescoring method.
""", """,
@ -106,7 +111,8 @@ def get_parser():
default=100, default=100,
help="""Number of paths for n-best based decoding method. help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values: Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle nbest, nbest-rescoring, attention-decoder, conformer-lm,
and nbest-oracle
""", """,
) )
@ -117,8 +123,8 @@ def get_parser():
help="""The scale to be applied to `lattice.scores`. help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring. It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values: Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle nbest, nbest-rescoring, attention-decoder, conformer-lm,
A smaller value results in more unique paths. and nbest-oracle. A smaller value results in more unique paths.
""", """,
) )
@ -147,6 +153,35 @@ def get_parser():
help="The lang dir", help="The lang dir",
) )
parser.add_argument(
"--conformer-lm-exp-dir",
type=str,
default="conformer_lm/exp",
help="""The conformer lm exp dir.
Used only when method is conformer_lm.
""",
)
parser.add_argument(
"--conformer-lm-epoch",
type=int,
default=19,
help="""Used only when method is conformer_lm.
It specifies the checkpoint to use for the conformer
lm model.
""",
)
parser.add_argument(
"--conformer-lm-avg",
type=int,
default=1,
help="""Used only when method is conformer_lm.
It specifies number of checkpoints to average for
the conformer lm model.
""",
)
return parser return parser
@ -177,6 +212,7 @@ def get_params() -> AttributeDict:
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
masked_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa], HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa], H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor], bpe_model: Optional[spm.SentencePieceProcessor],
@ -334,6 +370,7 @@ def decode_one_batch(
"nbest-rescoring", "nbest-rescoring",
"whole-lattice-rescoring", "whole-lattice-rescoring",
"attention-decoder", "attention-decoder",
"conformer-lm",
] ]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@ -354,7 +391,7 @@ def decode_one_batch(
G_with_epsilon_loops=G, G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
) )
elif params.method == "attention-decoder": elif params.method in ("attention-decoder", "conformer-lm"):
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice( rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, lattice=lattice,
@ -364,16 +401,32 @@ def decode_one_batch(
# TODO: pass `lattice` instead of `rescored_lattice` to # TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder` # `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder( if params.method == "attention-decoder":
lattice=rescored_lattice, best_path_dict = rescore_with_attention_decoder(
num_paths=params.num_paths, lattice=rescored_lattice,
model=model, num_paths=params.num_paths,
memory=memory, model=model,
memory_key_padding_mask=memory_key_padding_mask, memory=memory,
sos_id=sos_id, memory_key_padding_mask=memory_key_padding_mask,
eos_id=eos_id, sos_id=sos_id,
nbest_scale=params.nbest_scale, eos_id=eos_id,
) nbest_scale=params.nbest_scale,
)
else:
# Rescore with both the attention decoder
# and the conformer (masked) LM.
best_path_dict = rescore_with_conformer_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
masked_lm_model=masked_lm_model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0, # TODO(fangjun): pass it as an argument
nbest_scale=params.nbest_scale,
)
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"
@ -393,6 +446,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
masked_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa], HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa], H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor], bpe_model: Optional[spm.SentencePieceProcessor],
@ -449,6 +503,7 @@ def decode_dataset(
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
model=model, model=model,
masked_lm_model=masked_lm_model,
HLG=HLG, HLG=HLG,
H=H, H=H,
bpe_model=bpe_model, bpe_model=bpe_model,
@ -584,6 +639,7 @@ def main():
"nbest-rescoring", "nbest-rescoring",
"whole-lattice-rescoring", "whole-lattice-rescoring",
"attention-decoder", "attention-decoder",
"conformer-lm",
): ):
if not (params.lm_dir / "G_4_gram.pt").is_file(): if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt") logging.info("Loading G_4_gram.fst.txt")
@ -607,7 +663,11 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device) G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]: if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"conformer-lm",
]:
# Add epsilon self-loops to G as we will compose # Add epsilon self-loops to G as we will compose
# it with the whole lattice later # it with the whole lattice later
G = k2.add_epsilon_self_loops(G) G = k2.add_epsilon_self_loops(G)
@ -655,6 +715,38 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
if params.method == "conformer-lm":
logging.info("Loading conformer lm model")
# Note: If the parameters do not match
# the ones used to save the checkpoint,
# `load_state_dict` will throw an error.
masked_lm_model = MaskedLmConformer(
num_classes=num_classes,
d_model=params.attention_dim,
nhead=params.nhead,
num_decoder_layers=params.num_decoder_layers,
)
if params.conformer_lm_avg == 1:
load_checkpoint(
f"{params.conformer_lm_exp_dir}/epoch-{params.conformer_lm_epoch}.pt", # noqa
masked_lm_model,
)
else:
start = params.conformer_lm_epoch - params.conformer_lm_avg + 1
filenames = []
for i in range(start, params.conformer_lm_epoch + 1):
if start >= 0:
filenames.append(
f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
)
logging.info(f"averaging {filenames}")
masked_lm_model.load_state_dict(
average_checkpoints(filenames, device=device)
)
masked_lm_model.to(device)
masked_lm_model.eval()
else:
masked_lm_model = None
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only. # CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip # If you want to skip test-clean, you have to skip
@ -668,6 +760,7 @@ def main():
dl=test_dl, dl=test_dl,
params=params, params=params,
model=model, model=model,
masked_lm_model=masked_lm_model,
HLG=HLG, HLG=HLG,
H=H, H=H,
bpe_model=bpe_model, bpe_model=bpe_model,
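As a worked example of the new checkpoint-averaging options (the values below are hypothetical; the logic mirrors the averaging snippet above, and the resulting list is what gets passed to `average_checkpoints`):

conformer_lm_exp_dir = "conformer_lm/exp"  # default of --conformer-lm-exp-dir
conformer_lm_epoch = 19  # --conformer-lm-epoch
conformer_lm_avg = 5  # --conformer-lm-avg (hypothetical; the default is 1)

start = conformer_lm_epoch - conformer_lm_avg + 1
filenames = [
    f"{conformer_lm_exp_dir}/epoch-{i}.pt"
    for i in range(start, conformer_lm_epoch + 1)
    if start >= 0
]
print(filenames)
# ['conformer_lm/exp/epoch-15.pt', 'conformer_lm/exp/epoch-16.pt',
#  'conformer_lm/exp/epoch-17.pt', 'conformer_lm/exp/epoch-18.pt',
#  'conformer_lm/exp/epoch-19.pt']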

View File

@ -24,7 +24,7 @@ import logging
from pathlib import Path from pathlib import Path
import torch import torch
from conformer import Conformer from conformer_ctc.conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon

View File

@ -27,7 +27,7 @@ import kaldifeat
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from conformer import Conformer from conformer_ctc.conformer import Conformer
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.decode import ( from icefall.decode import (

View File

@ -17,8 +17,7 @@
import torch import torch
from torch.nn.utils.rnn import pad_sequence from conformer_ctc.transformer import (
from transformer import (
Transformer, Transformer,
add_eos, add_eos,
add_sos, add_sos,
@ -26,6 +25,7 @@ from transformer import (
encoder_padding_mask, encoder_padding_mask,
generate_square_subsequent_mask, generate_square_subsequent_mask,
) )
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask(): def test_encoder_padding_mask():

View File

@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling from conformer_ctc.subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed. # Note: TorchScript requires Dict/List/etc. to be fully typed.

View File

@ -1,693 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import MaskedLmConformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_env_info,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=19,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_lm/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="The lang dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"lm_dir": Path("data/lm"),
"num_tokens": 5000,
"blank_sym": 0,
"bos_sym": 1,
"eos_sym": 1,
# parameters for conformer
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a lit-of-list of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
nbest_scale=params.nbest_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
for lm_scale in lm_scale_list:
ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
assert num_classes == params.num_tokens
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = MaskedLmConformer(
num_classes=params.num_tokens,
d_model=params.attention_dim,
nhead=params.nhead,
num_decoder_layers=params.num_decoder_layers,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
print("TODO: Implement me!")
# [ ] Add an option to use conformer lm for rescoring
# [ ] Load conformer_lm only when that options is activated
# [ ] Load conformer model
return
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -20,6 +20,11 @@ from typing import Dict, List, Optional, Union
import k2 import k2
import torch import torch
from icefall.lm.rescore import (
compute_alignment,
make_hyp_to_ref_map,
prepare_conformer_lm_inputs,
)
from icefall.utils import get_texts from icefall.utils import get_texts
@ -224,6 +229,7 @@ class Nbest(object):
else: else:
word_seq = lattice.aux_labels.index(path) word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2) word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0)
# Each utterance has `num_paths` paths but some of them transduces # Each utterance has `num_paths` paths but some of them transduces
# to the same word sequence, so we need to remove repeated word # to the same word sequence, so we need to remove repeated word
@ -889,3 +895,198 @@ def rescore_with_attention_decoder(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path ans[key] = best_path
return ans return ans
def rescore_with_conformer_lm(
lattice: k2.Fsa,
num_paths: int,
model: torch.nn.Module,
masked_lm_model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
blank_id: int,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
masked_lm_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
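masked_lm_model:
A masked conformer LM. See the class "MaskedLmConformer" in
conformer_lm/conformer.py for its interface.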
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
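blank_id:
The token ID of the blank symbol. It is used when preparing the
inputs for the masked conformer LM.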
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
masked_lm_scale:
Optional. It specifies the scale for conformer_lm scores.
Returns:
A dict of FsaVec, whose key is a string of the form
ngram_lm_scale_xxx_attention_scale_xxx_masked_lm_scale_xxx and whose
value is the best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
# The `tokens` attribute is set inside `compile_hlg.py`
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# the shape of memory is (T, N, C), so we use axis=1 here
expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
# The shape of memory_key_padding_mask is (N, T), so we
# use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
# remove axis corresponding to states.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
alignment = compute_alignment(tokens, nbest.shape)
(
masked_src_symbols,
src_symbols,
tgt_symbols,
src_key_padding_mask,
tgt_weights,
) = prepare_conformer_lm_inputs(
alignment,
bos_id=sos_id,
eos_id=eos_id,
blank_id=blank_id,
unmasked_weight=0.0,
)
masked_src_symbols = masked_src_symbols.to(torch.int64)
src_symbols = src_symbols.to(torch.int64)
tgt_symbols = tgt_symbols.to(torch.int64)
masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
masked_src_symbols, src_key_padding_mask
)
tgt_nll = masked_lm_model.decoder_nll(
masked_lm_memory,
masked_lm_pos_emb,
src_symbols,
tgt_symbols,
src_key_padding_mask,
)
# nll means negative log-likelihood
# ll means log-likelihood
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
# Note: log-likelihoods for those pairs that have identical src/tgt are 0
# since their tgt_weights are 0
# TODO(fangjun): Add documentation about why we do the following
tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
tgt_ll_shape = k2.ragged.create_ragged_shape2(
row_splits=None,
row_ids=tgt_ll_shape_row_ids,
cached_tot_size=tgt_ll_shape_row_ids.numel(),
)
ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
masked_lm_scores = ragged_tgt_ll.max()
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
token_ids = tokens.tolist()
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
if masked_lm_scale is None:
masked_lm_scale_list = [0.01, 0.05, 0.08]
masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
masked_lm_scale_list = [masked_lm_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
for m_scale in masked_lm_scale_list:
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
+ m_scale * masked_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}" # noqa
ans[key] = best_path
return ans
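The nested grid search above just combines four per-path scores and then, per utterance, keeps the path with the largest total. A minimal, self-contained sketch of that combination with made-up numbers (only the formula, the key format, and the k2 calls follow the function above):

import k2
import torch

# Toy scores for 5 paths grouped into 2 utterances (2 paths + 3 paths).
shape = k2.RaggedShape("[[x x] [x x x]]")
am_scores = torch.tensor([-10.0, -11.5, -9.8, -10.2, -9.9])
ngram_lm_scores = torch.tensor([-3.2, -2.9, -3.5, -3.1, -3.0])
attention_scores = torch.tensor([-1.1, -0.9, -1.3, -1.0, -1.2])
masked_lm_scores = torch.tensor([-0.8, -1.0, -0.7, -0.9, -0.6])

n_scale, a_scale, m_scale = 0.5, 1.0, 0.5
tot_scores = (
    am_scores
    + n_scale * ngram_lm_scores
    + a_scale * attention_scores
    + m_scale * masked_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(shape, tot_scores)
# One winning path index per utterance; the real code uses these
# indexes with k2.index_fsa to pick the best paths from nbest.fsa.
max_indexes = ragged_tot_scores.argmax()
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}"
print(key, max_indexes)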

View File

@ -32,6 +32,8 @@ We can generate the following inputs for the conformer LM model from `tokens`:
- src - src
- tgt - tgt
by using `k2.levenshtein_alignment`. by using `k2.levenshtein_alignment`.
TODO(fangjun): Add more doc about rescoring with masked conformer-lm.
""" """
from typing import Tuple from typing import Tuple
@ -39,8 +41,6 @@ from typing import Tuple
import k2 import k2
import torch import torch
from icefall.decode import Nbest
def make_key_padding_mask(lengths: torch.Tensor): def make_key_padding_mask(lengths: torch.Tensor):
""" """
@ -236,7 +236,7 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
>>> tokens >>> tokens
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ] [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
>>> make_repeat(tokens) >>> make_repeat(tokens)
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] # noqa
TODO: Add documentation. TODO: Add documentation.
@ -300,6 +300,7 @@ def prepare_conformer_lm_inputs(
alignments: alignments:
It is computed by :func:`compute_alignment` It is computed by :func:`compute_alignment`
""" """
device = alignment.device
# alignment.arcs.shape has axes [fsa][state][arc] # alignment.arcs.shape has axes [fsa][state][arc]
# we remove axis 1, i.e., state, here # we remove axis 1, i.e., state, here
labels_shape = alignment.arcs.shape().remove_axis(1) labels_shape = alignment.arcs.shape().remove_axis(1)
@ -337,6 +338,7 @@ def prepare_conformer_lm_inputs(
(tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1), (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
fill_value=1, fill_value=1,
dtype=torch.float32, dtype=torch.float32,
device=device,
) )
# find unmasked positions # find unmasked positions
@ -359,52 +361,3 @@ def prepare_conformer_lm_inputs(
src_key_padding_mask, src_key_padding_mask,
weight, weight,
) )
def conformer_lm_rescore(
nbest: Nbest,
model: torch.nn.Module,
bos_id: int,
eos_id: int,
blank_id: int,
unmasked_weight: float = 0.25,
# TODO: add other arguments if needed
) -> k2.RaggedTensor:
"""Rescore an Nbest object with a conformer_lm model.
Args:
nbest:
It contains linear FSAs to be rescored.
model:
A conformer lm model. See "conformer_lm/train.py"
Returns:
Return a ragged tensor containing scores for each path
contained in the nbest. Its shape equals to `nbest.shape`.
"""
assert hasattr(nbest.fsa, "tokens")
utt_path_shape = nbest.shape
# nbest.fsa.arcs.shape() has axes [path][state][arc]
# We remove the state axis here
path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)
path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
path_token = path_token.remove_values_leq(0)
alignment = compute_alignment(path_token, utt_path_shape)
(
masked_src,
src,
tgt,
src_key_padding_mask,
weight,
) = prepare_conformer_lm_inputs(
alignment,
bos_id=bos_id,
eos_id=eos_id,
blank_id=blank_id,
unmasked_weight=unmasked_weight,
)
return masked_src, src, tgt, src_key_padding_mask, weight
# TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
# to the given model
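With `conformer_lm_rescore` removed, the masked-LM inputs can still be inspected by calling the remaining helpers directly. A minimal sketch, untested and modelled on the deleted `test_conformer_lm_rescore` (the token values and the bos/eos/blank IDs are arbitrary):

import k2

from icefall.lm.rescore import compute_alignment, prepare_conformer_lm_inputs

# Five hypothesis paths grouped into two utterances (2 paths + 3 paths).
tokens = k2.RaggedTensor(
    [[1, 2, 3, 5], [1, 5], [9, 8, 3, 2], [9, 8, 3, 2], [9, 8, 4, 2, 3]]
)
utt_path_shape = k2.RaggedShape("[[x x] [x x x]]")

alignment = compute_alignment(tokens, utt_path_shape)
(
    masked_src,
    src,
    tgt,
    src_key_padding_mask,
    weight,
) = prepare_conformer_lm_inputs(
    alignment, bos_id=10, eos_id=20, blank_id=0, unmasked_weight=0.25
)
print("masked_src", masked_src)
print("src", src)
print("tgt", tgt)
print("src_key_padding_mask", src_key_padding_mask)
print("weight", weight)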

View File

@ -17,12 +17,10 @@
import k2 import k2
import torch import torch
from icefall.decode import Nbest
from icefall.lm.rescore import ( from icefall.lm.rescore import (
add_bos, add_bos,
add_eos, add_eos,
compute_alignment, compute_alignment,
conformer_lm_rescore,
make_hyp_to_ref_map, make_hyp_to_ref_map,
make_repeat, make_repeat,
make_repeat_map, make_repeat_map,
@ -45,6 +43,7 @@ def test_add_eos():
expected = k2.RaggedTensor( expected = k2.RaggedTensor(
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]] [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
) )
assert str(ragged_eos) == str(expected)
def test_pad(): def test_pad():
@ -71,7 +70,7 @@ def test_make_hyp_to_ref_map():
repeat_map = make_hyp_to_ref_map(row_splits) repeat_map = make_hyp_to_ref_map(row_splits)
# fmt: off # fmt: off
expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) # noqa
# fmt: on # fmt: on
assert torch.all(torch.eq(repeat_map, expected)) assert torch.all(torch.eq(repeat_map, expected))
@ -82,8 +81,8 @@ def test_make_repeat_map():
repeat_map = make_repeat_map(row_splits) repeat_map = make_repeat_map(row_splits)
# fmt: off # fmt: off
expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, # noqa
3, 4, 5, 6]).to(repeat_map) 3, 4, 5, 6]).to(repeat_map) # noqa
# fmt: on # fmt: on
assert torch.all(torch.eq(repeat_map, expected)) assert torch.all(torch.eq(repeat_map, expected))
@ -132,27 +131,6 @@ def test_compute_alignment():
# print("weight", weight) # print("weight", weight)
def test_conformer_lm_rescore():
path00 = k2.linear_fsa([1, 2, 0, 3, 0, 5])
path01 = k2.linear_fsa([1, 0, 5, 0])
path10 = k2.linear_fsa([9, 8, 0, 3, 0, 2])
path11 = k2.linear_fsa([9, 8, 0, 0, 3, 2])
path12 = k2.linear_fsa([9, 0, 8, 4, 0, 2, 3])
fsa = k2.Fsa.from_fsas([path00, path01, path10, path11, path12])
fsa.tokens = fsa.labels.clone()
shape = k2.RaggedShape("[[x x] [x x x]]")
nbest = Nbest(fsa, shape)
masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
nbest, model=None, bos_id=10, eos_id=20, blank_id=0
)
print("masked src", masked_src)
print("src", src)
print("tgt", tgt)
print("src_key_padding_mask", src_key_padding_mask)
print("weight", weight)
def main(): def main():
test_add_bos() test_add_bos()
test_add_eos() test_add_eos()
@ -161,7 +139,6 @@ def main():
test_make_hyp_to_ref_map() test_make_hyp_to_ref_map()
test_make_repeat() test_make_repeat()
test_compute_alignment() test_compute_alignment()
test_conformer_lm_rescore()
if __name__ == "__main__": if __name__ == "__main__":