mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

First version using conformer lm for rescoring (not tested)

This commit is contained in:
parent 1ac9bb3fd7
commit cdd539e55c

@@ -21,8 +21,12 @@ import warnings
from typing import Optional, Tuple

import torch
from conformer_ctc.transformer import (
    Supervisions,
    Transformer,
    encoder_padding_mask,
)
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask


class Conformer(Transformer):

@@ -26,8 +26,9 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
from conformer_ctc.conformer import Conformer
from conformer_lm.conformer import MaskedLmConformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint

@@ -37,6 +38,7 @@ from icefall.decode import (
    nbest_oracle,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_conformer_lm,
    rescore_with_n_best_list,
    rescore_with_whole_lattice,
)

@@ -94,7 +96,10 @@ def get_parser():
          is the decoding result.
        - (5) attention-decoder. Extract n paths from the LM rescored
          lattice, the path with the highest score is the decoding result.
        - (6) nbest-oracle. Its WER is the lower bound that any n-best
        - (6) conformer-lm. In addition to attention-decoder rescoring, it
          also uses conformer lm for rescoring. See the model in the
          directory ./conformer_lm
        - (7) nbest-oracle. Its WER is the lower bound that any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring methods.
        """,
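
For orientation, here is a schematic of what the new conformer-lm method computes for each n-best path (this sketch is not part of the commit; the names mirror rescore_with_conformer_lm, added further below in icefall/decode.py):

    tot_score = (
        am_score
        + ngram_lm_scale * ngram_lm_score
        + attention_scale * attention_score
        + masked_lm_scale * masked_lm_score
    )

Within each utterance, the path with the largest tot_score becomes the decoding result.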

@@ -106,7 +111,8 @@ def get_parser():
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        nbest, nbest-rescoring, attention-decoder, conformer-lm,
        and nbest-oracle
        """,
    )

@@ -117,8 +123,8 @@ def get_parser():
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        A smaller value results in more unique paths.
        nbest, nbest-rescoring, attention-decoder, conformer_lm,
        and nbest-oracle. A smaller value results in more unique paths.
        """,
    )

@@ -147,6 +153,35 @@ def get_parser():
        help="The lang dir",
    )

    parser.add_argument(
        "--conformer-lm-exp-dir",
        type=str,
        default="conformer_lm/exp",
        help="""The conformer lm exp dir.
        Used only when method is conformer_lm.
        """,
    )

    parser.add_argument(
        "--conformer-lm-epoch",
        type=int,
        default=19,
        help="""Used only when method is conformer_lm.
        It specifies the checkpoint to use for the conformer
        lm model.
        """,
    )

    parser.add_argument(
        "--conformer-lm-avg",
        type=int,
        default=1,
        help="""Used only when method is conformer_lm.
        It specifies the number of checkpoints to average for
        the conformer lm model.
        """,
    )

    return parser
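
Illustrative only (not part of the diff): with the defaults above, the conformer LM checkpoint files selected for loading can be computed like this, mirroring the averaging loop added to main() further below (checkpoints are assumed to be named epoch-<i>.pt):

    exp_dir = "conformer_lm/exp"  # --conformer-lm-exp-dir
    epoch, avg = 19, 1            # --conformer-lm-epoch, --conformer-lm-avg
    start = epoch - avg + 1
    filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
    # avg == 1  ->  ["conformer_lm/exp/epoch-19.pt"]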

@@ -177,6 +212,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    masked_lm_model: Optional[nn.Module],
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],

@@ -334,6 +370,7 @@ def decode_one_batch(
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "conformer-lm",
    ]

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]

@@ -354,7 +391,7 @@ def decode_one_batch(
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif params.method == "attention-decoder":
    elif params.method in ("attention-decoder", "conformer-lm"):
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,

@@ -364,6 +401,7 @@ def decode_one_batch(
        # TODO: pass `lattice` instead of `rescored_lattice` to
        # `rescore_with_attention_decoder`

        if params.method == "attention-decoder":
            best_path_dict = rescore_with_attention_decoder(
                lattice=rescored_lattice,
                num_paths=params.num_paths,

@@ -374,6 +412,21 @@ def decode_one_batch(
                eos_id=eos_id,
                nbest_scale=params.nbest_scale,
            )
        else:
            # It uses:
            #   attention_decoder + conformer_lm
            best_path_dict = rescore_with_conformer_lm(
                lattice=rescored_lattice,
                num_paths=params.num_paths,
                model=model,
                masked_lm_model=masked_lm_model,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                sos_id=sos_id,
                eos_id=eos_id,
                blank_id=0,  # TODO(fangjun): pass it as an argument
                nbest_scale=params.nbest_scale,
            )
    else:
        assert False, f"Unsupported decoding method: {params.method}"

@@ -393,6 +446,7 @@ def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    masked_lm_model: Optional[nn.Module],
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],

@@ -449,6 +503,7 @@ def decode_dataset(
        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            masked_lm_model=masked_lm_model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,

@@ -584,6 +639,7 @@ def main():
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "conformer-lm",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")

@@ -607,7 +663,11 @@ def main():
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
            G = k2.Fsa.from_dict(d).to(device)

        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
        if params.method in [
            "whole-lattice-rescoring",
            "attention-decoder",
            "conformer-lm",
        ]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)

@@ -655,6 +715,38 @@ def main():
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    if params.method == "conformer-lm":
        logging.info("Loading conformer lm model")
        # Note: If the parameters do not match
        # the ones used to save the checkpoint, it will
        # throw while calling `load_state_dict`.
        masked_lm_model = MaskedLmConformer(
            num_classes=num_classes,
            d_model=params.attention_dim,
            nhead=params.nhead,
            num_decoder_layers=params.num_decoder_layers,
        )
        if params.conformer_lm_avg == 1:
            load_checkpoint(
                f"{params.conformer_lm_exp_dir}/epoch-{params.conformer_lm_epoch}.pt",  # noqa
                masked_lm_model,
            )
        else:
            start = params.conformer_lm_epoch - params.conformer_lm_avg + 1
            filenames = []
            for i in range(start, params.conformer_lm_epoch + 1):
                if start >= 0:
                    filenames.append(
                        f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
                    )
            logging.info(f"averaging {filenames}")
            masked_lm_model.to(device)
            masked_lm_model.load_state_dict(
                average_checkpoints(filenames, device=device)
            )
    else:
        masked_lm_model = None

    librispeech = LibriSpeechAsrDataModule(args)
    # CAUTION: `test_sets` is for displaying only.
    # If you want to skip test-clean, you have to skip

@@ -668,6 +760,7 @@ def main():
            dl=test_dl,
            params=params,
            model=model,
            masked_lm_model=masked_lm_model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
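
As a rough sketch of what the checkpoint-averaging branch in the hunk above amounts to (illustrative only; the real helper is icefall.checkpoint.average_checkpoints, and the {"model": state_dict} checkpoint layout is assumed from this recipe):

    import torch

    def average_state_dicts(filenames, device="cpu"):
        # Load each checkpoint's "model" state dict and average the
        # floating-point parameters element-wise.
        sds = [torch.load(f, map_location=device)["model"] for f in filenames]
        return {k: sum(sd[k] for sd in sds) / len(sds) for k in sds[0]}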

@@ -24,7 +24,7 @@ import logging
from pathlib import Path

import torch
from conformer import Conformer
from conformer_ctc.conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon

@@ -27,7 +27,7 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
from conformer_ctc.conformer import Conformer
from torch.nn.utils.rnn import pad_sequence

from icefall.decode import (

@@ -17,8 +17,7 @@


import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import (
from conformer_ctc.transformer import (
    Transformer,
    add_eos,
    add_sos,

@@ -26,6 +25,7 @@ from transformer import (
    encoder_padding_mask,
    generate_square_subsequent_mask,
)
from torch.nn.utils.rnn import pad_sequence


def test_encoder_padding_mask():

@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from conformer_ctc.subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence

# Note: TorchScript requires Dict/List/etc. to be fully typed.

@@ -1,693 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import MaskedLmConformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
    get_lattice,
    nbest_decoding,
    nbest_oracle,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_n_best_list,
    rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_env_info,
    get_texts,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=19,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=5,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="attention-decoder",
        help="""Decoding method.
        Supported values are:
        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
          It needs neither a lexicon nor an n-gram LM.
        - (1) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (2) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
        - (3) nbest-rescoring. Extract n paths from the decoding lattice,
          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
          the highest score is the decoding result.
        - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
          is the decoding result.
        - (5) attention-decoder. Extract n paths from the LM rescored
          lattice, the path with the highest score is the decoding result.
        - (6) nbest-oracle. Its WER is the lower bound that any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring methods.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        A smaller value results in more unique paths.
        """,
    )

    parser.add_argument(
        "--export",
        type=str2bool,
        default=False,
        help="""When enabled, the averaged model is saved to
        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
        pretrained.pt contains a dict {"model": model.state_dict()},
        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_lm/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_5000",
        help="The lang dir",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "lm_dir": Path("data/lm"),
            "num_tokens": 5000,
            "blank_sym": 0,
            "bos_sym": 1,
            "eos_sym": 1,
            # parameters for conformer
            "attention_dim": 512,
            "nhead": 8,
            "num_decoder_layers": 6,
            # parameters for decoding
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
      if no rescoring is used, the key is the string `no_rescore`.
      If LM rescoring is used, the key is the string `lm_scale_xxx`,
      where `xxx` is the value of `lm_scale`. An example key is
      `lm_scale_0.7`
    - value: It contains the decoding result. `len(value)` equals to
      batch size. `value[i]` is the decoding result for the i-th
      utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.

        - params.method is "1best", it uses 1best decoding without LM rescoring.
        - params.method is "nbest", it uses nbest decoding without LM rescoring.
        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
          rescoring.

      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.method is ctc-decoding.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      sos_id:
        The token ID of the SOS.
      eos_id:
        The token ID of the EOS.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    if HLG is not None:
        device = HLG.device
    else:
        device = H.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]

    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
    # nnet_output is (N, T, C)

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // params.subsampling_factor,
            supervisions["num_frames"] // params.subsampling_factor,
        ),
        1,
    ).to(torch.int32)

    if H is None:
        assert HLG is not None
        decoding_graph = HLG
    else:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    if params.method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)

        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
        hyps = bpe_model.decode(token_ids)

        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in hyps]
        key = "ctc-decoding"
        return {key: hyps}

    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: hyps}

    if params.method in ["1best", "nbest"]:
        if params.method == "1best":
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
            key = "no_rescore"
        else:
            best_path = nbest_decoding(
                lattice=lattice,
                num_paths=params.num_paths,
                use_double_scores=params.use_double_scores,
                nbest_scale=params.nbest_scale,
            )
            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
        return {key: hyps}

    assert params.method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ]

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if params.method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif params.method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif params.method == "attention-decoder":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=None,
        )
        # TODO: pass `lattice` instead of `rescored_lattice` to
        # `rescore_with_attention_decoder`

        best_path_dict = rescore_with_attention_decoder(
            lattice=rescored_lattice,
            num_paths=params.num_paths,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            nbest_scale=params.nbest_scale,
        )
    else:
        assert False, f"Unsupported decoding method: {params.method}"

    ans = dict()
    if best_path_dict is not None:
        for lm_scale_str, best_path in best_path_dict.items():
            hyps = get_texts(best_path)
            hyps = [[word_table[i] for i in ids] for ids in hyps]
            ans[lm_scale_str] = hyps
    else:
        for lm_scale in lm_scale_list:
            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
    return ans


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.method is ctc-decoding.
      word_table:
        It is the word symbol table.
      sos_id:
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict, whose key may be "no-rescore" if no LM rescoring
      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    results = []

    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for hyp_words, ref_text in zip(hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        num_cuts += len(batch["supervisions"]["text"])

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
    if params.method == "attention-decoder":
        # Set it to False since there are too many logs.
        enable_log = False
    else:
        enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=enable_log
            )
            test_set_wers[key] = wer

        if enable_log:
            logging.info(
                "Wrote detailed error stats to {}".format(errs_filename)
            )

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank
    assert num_classes == params.num_tokens

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    graph_compiler = BpeCtcTrainingGraphCompiler(
        params.lang_dir,
        device=device,
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )
    sos_id = graph_compiler.sos_id
    eos_id = graph_compiler.eos_id

    if params.method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
        )
        HLG = HLG.to(device)
        assert HLG.requires_grad is False

        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
            G = k2.Fsa.from_dict(d).to(device)

        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    model = MaskedLmConformer(
        num_classes=params.num_tokens,
        d_model=params.attention_dim,
        nhead=params.nhead,
        num_decoder_layers=params.num_decoder_layers,
    )

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    if params.export:
        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
        torch.save(
            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
        )
        return
    print("TODO: Implement me!")
    # [ ] Add an option to use conformer lm for rescoring
    # [ ] Load conformer_lm only when that option is activated
    # [ ] Load conformer model
    return

    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)
    # CAUTION: `test_sets` is for displaying only.
    # If you want to skip test-clean, you have to skip
    # it inside the for loop. That is, use
    #
    #   if test_set == 'test-clean': continue
    #
    test_sets = ["test-clean", "test-other"]
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        save_results(
            params=params, test_set_name=test_set, results_dict=results_dict
        )

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()

@@ -20,6 +20,11 @@ from typing import Dict, List, Optional, Union
import k2
import torch

from icefall.lm.rescore import (
    compute_alignment,
    make_hyp_to_ref_map,
    prepare_conformer_lm_inputs,
)
from icefall.utils import get_texts


@@ -224,6 +229,7 @@ class Nbest(object):
        else:
            word_seq = lattice.aux_labels.index(path)
            word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
            word_seq = word_seq.remove_values_leq(0)

        # Each utterance has `num_paths` paths but some of them transduce
        # to the same word sequence, so we need to remove repeated word

@@ -889,3 +895,198 @@ def rescore_with_attention_decoder(
            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = best_path
    return ans


def rescore_with_conformer_lm(
    lattice: k2.Fsa,
    num_paths: int,
    model: torch.nn.Module,
    masked_lm_model: torch.nn.Module,
    memory: torch.Tensor,
    memory_key_padding_mask: Optional[torch.Tensor],
    sos_id: int,
    eos_id: int,
    blank_id: int,
    nbest_scale: float = 1.0,
    ngram_lm_scale: Optional[float] = None,
    attention_scale: Optional[float] = None,
    masked_lm_scale: Optional[float] = None,
    use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
    """This function extracts `num_paths` paths from the given lattice and uses
    both an attention decoder and a masked conformer LM to rescore them. The
    path with the highest score is the decoding output.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc].
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
      model:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
      masked_lm_model:
        The conformer LM, i.e., a MaskedLmConformer from conformer_lm,
        used for the extra rescoring pass.
      memory:
        The encoder memory of the given model. It is the output of
        the last torch.nn.TransformerEncoder layer in the given model.
        Its shape is `(T, N, C)`.
      memory_key_padding_mask:
        The padding mask for memory with shape `(N, T)`.
      sos_id:
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      blank_id:
        The token ID for the blank symbol.
      nbest_scale:
        It's the scale applied to `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
      ngram_lm_scale:
        Optional. It specifies the scale for n-gram LM scores.
      attention_scale:
        Optional. It specifies the scale for attention decoder scores.
      masked_lm_scale:
        Optional. It specifies the scale for conformer_lm scores.
    Returns:
      A dict of FsaVec, whose key contains a string of the form
      ngram_lm_scale_X_attention_scale_Y_masked_lm_scale_Z and the value
      is the best decoding path for each utterance in the lattice.
    """
    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        nbest_scale=nbest_scale,
    )
    # nbest.fsa.scores are all 0s at this point

    nbest = nbest.intersect(lattice)
    # Now nbest.fsa has its scores set.
    # Also, nbest.fsa inherits the attributes from `lattice`.
    assert hasattr(nbest.fsa, "lm_scores")

    am_scores = nbest.compute_am_scores()
    ngram_lm_scores = nbest.compute_lm_scores()

    # The `tokens` attribute is set inside `compile_hlg.py`
    assert hasattr(nbest.fsa, "tokens")
    assert isinstance(nbest.fsa.tokens, torch.Tensor)

    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
    # the shape of memory is (T, N, C), so we use axis=1 here
    expanded_memory = memory.index_select(1, path_to_utt_map)

    if memory_key_padding_mask is not None:
        # The shape of memory_key_padding_mask is (N, T), so we
        # use axis=0 here.
        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
            0, path_to_utt_map
        )
    else:
        expanded_memory_key_padding_mask = None

    # remove axis corresponding to states.
    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
    tokens = tokens.remove_values_leq(0)

    alignment = compute_alignment(tokens, nbest.shape)
    (
        masked_src_symbols,
        src_symbols,
        tgt_symbols,
        src_key_padding_mask,
        tgt_weights,
    ) = prepare_conformer_lm_inputs(
        alignment,
        bos_id=sos_id,
        eos_id=eos_id,
        blank_id=blank_id,
        unmasked_weight=0.0,
    )

    masked_src_symbols = masked_src_symbols.to(torch.int64)
    src_symbols = src_symbols.to(torch.int64)
    tgt_symbols = tgt_symbols.to(torch.int64)

    masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
        masked_src_symbols, src_key_padding_mask
    )

    tgt_nll = masked_lm_model.decoder_nll(
        masked_lm_memory,
        masked_lm_pos_emb,
        src_symbols,
        tgt_symbols,
        src_key_padding_mask,
    )

    # nll means negative log-likelihood
    # ll means log-likelihood
    tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)

    # Note: the log-likelihood for those pairs that have identical src/tgt
    # is 0 since their tgt_weights are 0

    # TODO(fangjun): Add documentation about why we do the following
    tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
    tgt_ll_shape = k2.ragged.create_ragged_shape2(
        row_splits=None,
        row_ids=tgt_ll_shape_row_ids,
        cached_tot_size=tgt_ll_shape_row_ids.numel(),
    )
    ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)

    ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
    masked_lm_scores = ragged_tgt_ll.max()

    # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
    token_ids = tokens.tolist()

    nll = model.decoder_nll(
        memory=expanded_memory,
        memory_key_padding_mask=expanded_memory_key_padding_mask,
        token_ids=token_ids,
        sos_id=sos_id,
        eos_id=eos_id,
    )
    assert nll.ndim == 2
    assert nll.shape[0] == len(token_ids)

    attention_scores = -nll.sum(dim=1)

    if ngram_lm_scale is None:
        ngram_lm_scale_list = [0.01, 0.05, 0.08]
        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    else:
        ngram_lm_scale_list = [ngram_lm_scale]

    if attention_scale is None:
        attention_scale_list = [0.01, 0.05, 0.08]
        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    else:
        attention_scale_list = [attention_scale]

    if masked_lm_scale is None:
        masked_lm_scale_list = [0.01, 0.05, 0.08]
        masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    else:
        masked_lm_scale_list = [masked_lm_scale]

    ans = dict()
    for n_scale in ngram_lm_scale_list:
        for a_scale in attention_scale_list:
            for m_scale in masked_lm_scale_list:
                tot_scores = (
                    am_scores.values
                    + n_scale * ngram_lm_scores.values
                    + a_scale * attention_scores
                    + m_scale * masked_lm_scores
                )
                ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
                max_indexes = ragged_tot_scores.argmax()
                best_path = k2.index_fsa(nbest.fsa, max_indexes)

                key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}"  # noqa
                ans[key] = best_path
    return ans
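
A side note on cost (not part of the diff): when ngram_lm_scale, attention_scale and masked_lm_scale are all left as None, each scale list above holds 17 values, so the triple loop produces

    num_combinations = (3 + 7 + 7) ** 3  # 17 ** 3 == 4913

entries in the returned dict. The expensive quantities (AM, n-gram, attention and masked-LM scores) are computed once before the loop, so each combination only costs a weighted sum plus an argmax over the n-best paths.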

@@ -32,6 +32,8 @@ We can generate the following inputs for the conformer LM model from `tokens`:
- src
- tgt
by using `k2.levenshtein_alignment`.

TODO(fangjun): Add more doc about rescoring with masked conformer-lm.
"""

from typing import Tuple

@@ -39,8 +41,6 @@ from typing import Tuple
import k2
import torch

from icefall.decode import Nbest


def make_key_padding_mask(lengths: torch.Tensor):
    """

@@ -236,7 +236,7 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
    >>> tokens
    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
    >>> make_repeat(tokens)
    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ]
    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ]  # noqa

    TODO: Add documentation.

@@ -300,6 +300,7 @@ def prepare_conformer_lm_inputs(
      alignments:
        It is computed by :func:`compute_alignment`
    """
    device = alignment.device
    # alignment.arcs.shape has axes [fsa][state][arc]
    # we remove axis 1, i.e., state, here
    labels_shape = alignment.arcs.shape().remove_axis(1)

@@ -337,6 +338,7 @@ def prepare_conformer_lm_inputs(
        (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
        fill_value=1,
        dtype=torch.float32,
        device=device,
    )

    # find unmasked positions

@@ -359,52 +361,3 @@ def prepare_conformer_lm_inputs(
        src_key_padding_mask,
        weight,
    )


def conformer_lm_rescore(
    nbest: Nbest,
    model: torch.nn.Module,
    bos_id: int,
    eos_id: int,
    blank_id: int,
    unmasked_weight: float = 0.25,
    # TODO: add other arguments if needed
) -> k2.RaggedTensor:
    """Rescore an Nbest object with a conformer_lm model.

    Args:
      nbest:
        It contains linear FSAs to be rescored.
      model:
        A conformer lm model. See "conformer_lm/train.py"

    Returns:
      Return a ragged tensor containing scores for each path
      contained in the nbest. Its shape equals to `nbest.shape`.
    """
    assert hasattr(nbest.fsa, "tokens")
    utt_path_shape = nbest.shape
    # nbest.fsa.arcs.shape() has axes [path][state][arc]
    # We remove the state axis here
    path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)

    path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
    path_token = path_token.remove_values_leq(0)

    alignment = compute_alignment(path_token, utt_path_shape)
    (
        masked_src,
        src,
        tgt,
        src_key_padding_mask,
        weight,
    ) = prepare_conformer_lm_inputs(
        alignment,
        bos_id=bos_id,
        eos_id=eos_id,
        blank_id=blank_id,
        unmasked_weight=unmasked_weight,
    )
    return masked_src, src, tgt, src_key_padding_mask, weight
    # TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
    # to the given model

@@ -17,12 +17,10 @@
import k2
import torch

from icefall.decode import Nbest
from icefall.lm.rescore import (
    add_bos,
    add_eos,
    compute_alignment,
    conformer_lm_rescore,
    make_hyp_to_ref_map,
    make_repeat,
    make_repeat_map,

@@ -45,6 +43,7 @@ def test_add_eos():
    expected = k2.RaggedTensor(
        [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
    )
    assert str(ragged_eos) == str(expected)


def test_pad():

@@ -71,7 +70,7 @@ def test_make_hyp_to_ref_map():
    repeat_map = make_hyp_to_ref_map(row_splits)
    # fmt: off
    expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
                             3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map)
                             3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map)  # noqa
    # fmt: on
    assert torch.all(torch.eq(repeat_map, expected))

@@ -82,8 +81,8 @@ def test_make_repeat_map():
    repeat_map = make_repeat_map(row_splits)
    # fmt: off
    expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
                             3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
                             3, 4, 5, 6]).to(repeat_map)
                             3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,  # noqa
                             3, 4, 5, 6]).to(repeat_map)  # noqa
    # fmt: on
    assert torch.all(torch.eq(repeat_map, expected))

@@ -132,27 +131,6 @@ def test_compute_alignment():
    # print("weight", weight)


def test_conformer_lm_rescore():
    path00 = k2.linear_fsa([1, 2, 0, 3, 0, 5])
    path01 = k2.linear_fsa([1, 0, 5, 0])
    path10 = k2.linear_fsa([9, 8, 0, 3, 0, 2])
    path11 = k2.linear_fsa([9, 8, 0, 0, 3, 2])
    path12 = k2.linear_fsa([9, 0, 8, 4, 0, 2, 3])

    fsa = k2.Fsa.from_fsas([path00, path01, path10, path11, path12])
    fsa.tokens = fsa.labels.clone()
    shape = k2.RaggedShape("[[x x] [x x x]]")
    nbest = Nbest(fsa, shape)
    masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
        nbest, model=None, bos_id=10, eos_id=20, blank_id=0
    )
    print("masked src", masked_src)
    print("src", src)
    print("tgt", tgt)
    print("src_key_padding_mask", src_key_padding_mask)
    print("weight", weight)


def main():
    test_add_bos()
    test_add_eos()

@@ -161,7 +139,6 @@ def main():
    test_make_hyp_to_ref_map()
    test_make_repeat()
    test_compute_alignment()
    test_conformer_lm_rescore()


if __name__ == "__main__":