First version using conformer lm for rescoring (not tested)
This commit is contained in:
parent 1ac9bb3fd7
commit cdd539e55c
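This commit wires a new `conformer-lm` decoding method into the CTC decoding script and adds a `rescore_with_conformer_lm()` helper to `icefall/decode.py`. For each of the `num_paths` paths extracted from the lattice, the helper combines four per-path scores: the acoustic score, the n-gram LM score, the attention-decoder score, and the new masked conformer LM score, then keeps the highest-scoring path per utterance. Below is a minimal sketch of that combination step, paraphrasing the scale loop inside `rescore_with_conformer_lm()` further down in the diff; the function name and argument names in the sketch are illustrative and not part of the commit.

import k2
import torch


def pick_best_paths(
    nbest,                           # an icefall Nbest holding the extracted paths (assumed)
    am_scores: torch.Tensor,         # per-path acoustic scores
    ngram_lm_scores: torch.Tensor,   # per-path n-gram LM scores
    attention_scores: torch.Tensor,  # per-path attention-decoder scores
    masked_lm_scores: torch.Tensor,  # per-path masked conformer LM scores
    n_scale: float,
    a_scale: float,
    m_scale: float,
) -> k2.Fsa:
    # Weighted sum of the four scores for every path.
    tot_scores = (
        am_scores
        + n_scale * ngram_lm_scores
        + a_scale * attention_scores
        + m_scale * masked_lm_scores
    )
    # Group the flat per-path scores by utterance and take the argmax,
    # as the scale loop in rescore_with_conformer_lm() does below.
    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
    max_indexes = ragged_tot_scores.argmax()
    return k2.index_fsa(nbest.fsa, max_indexes)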
@@ -21,8 +21,12 @@ import warnings
 from typing import Optional, Tuple

 import torch
+from conformer_ctc.transformer import (
+    Supervisions,
+    Transformer,
+    encoder_padding_mask,
+)
 from torch import Tensor, nn
-from transformer import Supervisions, Transformer, encoder_padding_mask


 class Conformer(Transformer):
@@ -26,8 +26,9 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
+from conformer_ctc.conformer import Conformer
+from conformer_lm.conformer import MaskedLmConformer

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -37,6 +38,7 @@ from icefall.decode import (
     nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
+    rescore_with_conformer_lm,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -94,7 +96,10 @@ def get_parser():
           is the decoding result.
         - (5) attention-decoder. Extract n paths from the LM rescored
           lattice, the path with the highest score is the decoding result.
-        - (6) nbest-oracle. Its WER is the lower bound of any n-best
+        - (6) conformer-lm. In addition to attention-decoder rescoring, it
+          also uses conformer lm for rescoring. See the model in the
+          directory ./conformer_lm
+        - (7) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
         """,
@@ -106,7 +111,8 @@ def get_parser():
         default=100,
         help="""Number of paths for n-best based decoding method.
         Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        nbest, nbest-rescoring, attention-decoder, conformer-lm,
+        and nbest-oracle
         """,
     )

@@ -117,8 +123,8 @@ def get_parser():
         help="""The scale to be applied to `lattice.scores`.
         It's needed if you use any kinds of n-best based rescoring.
         Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
-        A smaller value results in more unique paths.
+        nbest, nbest-rescoring, attention-decoder, conformer_lm,
+        and nbest-oracle. A smaller value results in more unique paths.
         """,
     )

@@ -147,6 +153,35 @@ def get_parser():
         help="The lang dir",
     )

+    parser.add_argument(
+        "--conformer-lm-exp-dir",
+        type=str,
+        default="conformer_lm/exp",
+        help="""The conformer lm exp dir.
+        Used only when method is conformer_lm.
+        """,
+    )
+
+    parser.add_argument(
+        "--conformer-lm-epoch",
+        type=int,
+        default=19,
+        help="""Used only when method is conformer_lm.
+        It specifies the checkpoint to use for the conformer
+        lm model.
+        """,
+    )
+
+    parser.add_argument(
+        "--conformer-lm-avg",
+        type=int,
+        default=1,
+        help="""Used only when method is conformer_lm.
+        It specifies number of checkpoints to average for
+        the conformer lm model.
+        """,
+    )
+
     return parser


@@ -177,6 +212,7 @@ def get_params() -> AttributeDict:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
+    masked_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
@@ -334,6 +370,7 @@ def decode_one_batch(
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "attention-decoder",
+        "conformer-lm",
     ]

     lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@@ -354,7 +391,7 @@ def decode_one_batch(
             G_with_epsilon_loops=G,
             lm_scale_list=lm_scale_list,
         )
-    elif params.method == "attention-decoder":
+    elif params.method in ("attention-decoder", "conformer-lm"):
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
             lattice=lattice,
@@ -364,16 +401,32 @@ def decode_one_batch(
         # TODO: pass `lattice` instead of `rescored_lattice` to
         # `rescore_with_attention_decoder`

-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            nbest_scale=params.nbest_scale,
-        )
+        if params.method == "attention-decoder":
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                nbest_scale=params.nbest_scale,
+            )
+        else:
+            # It uses:
+            # attention_decoder + conformer_lm
+            best_path_dict = rescore_with_conformer_lm(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                masked_lm_model=masked_lm_model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                blank_id=0,  # TODO(fangjun): pass it as an argument
+                nbest_scale=params.nbest_scale,
+            )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

@@ -393,6 +446,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
+    masked_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
@@ -449,6 +503,7 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
+            masked_lm_model=masked_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,
@@ -584,6 +639,7 @@ def main():
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "attention-decoder",
+        "conformer-lm",
     ):
         if not (params.lm_dir / "G_4_gram.pt").is_file():
             logging.info("Loading G_4_gram.fst.txt")
@@ -607,7 +663,11 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
             G = k2.Fsa.from_dict(d).to(device)

-        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+            "conformer-lm",
+        ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
@@ -655,6 +715,38 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    if params.method == "conformer-lm":
+        logging.info("Loading conformer lm model")
+        # Note: If the parameters does not match
+        # the one used to save the checkpoint, it will
+        # throw while calling `load_state_dict`.
+        masked_lm_model = MaskedLmConformer(
+            num_classes=num_classes,
+            d_model=params.attention_dim,
+            nhead=params.nhead,
+            num_decoder_layers=params.num_decoder_layers,
+        )
+        if params.conformer_lm_avg == 1:
+            load_checkpoint(
+                f"{params.conformer_lm_exp_dir}/epoch-{params.conformer_lm_epoch}.pt",  # noqa
+                masked_lm_model,
+            )
+        else:
+            start = params.conformer_lm_epoch - params.conformer_lm_avg + 1
+            filenames = []
+            for i in range(start, params.conformer_lm_epoch + 1):
+                if start >= 0:
+                    filenames.append(
+                        f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
+                    )
+            logging.info(f"averaging {filenames}")
+            masked_lm_model.to(device)
+            masked_lm_model.load_state_dict(
+                average_checkpoints(filenames, device=device)
+            )
+    else:
+        masked_lm_model = None
+
     librispeech = LibriSpeechAsrDataModule(args)
     # CAUTION: `test_sets` is for displaying only.
     # If you want to skip test-clean, you have to skip
@@ -668,6 +760,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            masked_lm_model=masked_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,
@@ -24,7 +24,7 @@ import logging
 from pathlib import Path

 import torch
-from conformer import Conformer
+from conformer_ctc.conformer import Conformer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -27,7 +27,7 @@ import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from conformer import Conformer
+from conformer_ctc.conformer import Conformer
 from torch.nn.utils.rnn import pad_sequence

 from icefall.decode import (
@@ -17,8 +17,7 @@


 import torch
-from torch.nn.utils.rnn import pad_sequence
-from transformer import (
+from conformer_ctc.transformer import (
     Transformer,
     add_eos,
     add_sos,
@@ -26,6 +25,7 @@ from transformer import (
     encoder_padding_mask,
     generate_square_subsequent_mask,
 )
+from torch.nn.utils.rnn import pad_sequence


 def test_encoder_padding_mask():
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
-from subsampling import Conv2dSubsampling, VggSubsampling
+from conformer_ctc.subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence

 # Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -1,693 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import argparse
-import logging
-from collections import defaultdict
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import k2
-import sentencepiece as spm
-import torch
-import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import MaskedLmConformer
-
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
-)
-from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    get_env_info,
-    get_texts,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-    write_error_stats,
-)
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
-    )
-    parser.add_argument(
-        "--avg",
-        type=int,
-        default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--method",
-        type=str,
-        default="attention-decoder",
-        help="""Decoding method.
-        Supported values are:
-        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
-          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
-          It needs neither a lexicon nor an n-gram LM.
-        - (1) 1best. Extract the best path from the decoding lattice as the
-          decoding result.
-        - (2) nbest. Extract n paths from the decoding lattice; the path
-          with the highest score is the decoding result.
-        - (3) nbest-rescoring. Extract n paths from the decoding lattice,
-          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
-          the highest score is the decoding result.
-        - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
-          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
-          is the decoding result.
-        - (5) attention-decoder. Extract n paths from the LM rescored
-          lattice, the path with the highest score is the decoding result.
-        - (6) nbest-oracle. Its WER is the lower bound of any n-best
-          rescoring method can achieve. Useful for debugging n-best
-          rescoring method.
-        """,
-    )
-
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""Number of paths for n-best based decoding method.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
-        """,
-    )
-
-    parser.add_argument(
-        "--nbest-scale",
-        type=float,
-        default=0.5,
-        help="""The scale to be applied to `lattice.scores`.
-        It's needed if you use any kinds of n-best based rescoring.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
-        A smaller value results in more unique paths.
-        """,
-    )
-
-    parser.add_argument(
-        "--export",
-        type=str2bool,
-        default=False,
-        help="""When enabled, the averaged model is saved to
-        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-        pretrained.pt contains a dict {"model": model.state_dict()},
-        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-        """,
-    )
-
-    parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="conformer_lm/exp",
-        help="The experiment dir",
-    )
-
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_bpe_5000",
-        help="The lang dir",
-    )
-
-    return parser
-
-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "lm_dir": Path("data/lm"),
-            "num_tokens": 5000,
-            "blank_sym": 0,
-            "bos_sym": 1,
-            "eos_sym": 1,
-            # parameters for conformer
-            "attention_dim": 512,
-            "nhead": 8,
-            "num_decoder_layers": 6,
-            # parameters for decoding
-            "search_beam": 20,
-            "output_beam": 8,
-            "min_active_states": 30,
-            "max_active_states": 10000,
-            "use_double_scores": True,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def decode_one_batch(
-    params: AttributeDict,
-    model: nn.Module,
-    HLG: Optional[k2.Fsa],
-    H: Optional[k2.Fsa],
-    bpe_model: Optional[spm.SentencePieceProcessor],
-    batch: dict,
-    word_table: k2.SymbolTable,
-    sos_id: int,
-    eos_id: int,
-    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
-    """Decode one batch and return the result in a dict. The dict has the
-    following format:
-
-    - key: It indicates the setting used for decoding. For example,
-           if no rescoring is used, the key is the string `no_rescore`.
-           If LM rescoring is used, the key is the string `lm_scale_xxx`,
-           where `xxx` is the value of `lm_scale`. An example key is
-           `lm_scale_0.7`
-    - value: It contains the decoding result. `len(value)` equals to
-             batch size. `value[i]` is the decoding result for the i-th
-             utterance in the given batch.
-    Args:
-      params:
-        It's the return value of :func:`get_params`.
-
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
-        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
-          rescoring.
-
-      model:
-        The neural model.
-      HLG:
-        The decoding graph. Used only when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      bpe_model:
-        The BPE model. Used only when params.method is ctc-decoding.
-      batch:
-        It is the return value from iterating
-        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
-        for the format of the `batch`.
-      word_table:
-        The word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.
-    Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
-    """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
-    feature = batch["inputs"]
-    assert feature.ndim == 3
-    feature = feature.to(device)
-    # at entry, feature is (N, T, C)
-
-    supervisions = batch["supervisions"]
-
-    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is (N, T, C)
-
-    supervision_segments = torch.stack(
-        (
-            supervisions["sequence_idx"],
-            supervisions["start_frame"] // params.subsampling_factor,
-            supervisions["num_frames"] // params.subsampling_factor,
-        ),
-        1,
-    ).to(torch.int32)
-
-    if H is None:
-        assert HLG is not None
-        decoding_graph = HLG
-    else:
-        assert HLG is None
-        assert bpe_model is not None
-        decoding_graph = H
-
-    lattice = get_lattice(
-        nnet_output=nnet_output,
-        decoding_graph=decoding_graph,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
-
-    if params.method == "ctc-decoding":
-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
-        # Note: `best_path.aux_labels` contains token IDs, not word IDs
-        # since we are using H, not HLG here.
-        #
-        # token_ids is a lit-of-list of IDs
-        token_ids = get_texts(best_path)
-
-        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
-        hyps = bpe_model.decode(token_ids)
-
-        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
-        hyps = [s.split() for s in hyps]
-        key = "ctc-decoding"
-        return {key: hyps}
-
-    if params.method == "nbest-oracle":
-        # Note: You can also pass rescored lattices to it.
-        # We choose the HLG decoded lattice for speed reasons
-        # as HLG decoding is faster and the oracle WER
-        # is only slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            nbest_scale=params.nbest_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
-        return {key: hyps}
-
-    if params.method in ["1best", "nbest"]:
-        if params.method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-                nbest_scale=params.nbest_scale,
-            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
-
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
-
-    assert params.method in [
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-    ]
-
-    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=lm_scale_list,
-        )
-    elif params.method == "attention-decoder":
-        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=None,
-        )
-        # TODO: pass `lattice` instead of `rescored_lattice` to
-        # `rescore_with_attention_decoder`
-
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            nbest_scale=params.nbest_scale,
-        )
-    else:
-        assert False, f"Unsupported decoding method: {params.method}"
-
-    ans = dict()
-    if best_path_dict is not None:
-        for lm_scale_str, best_path in best_path_dict.items():
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            ans[lm_scale_str] = hyps
-    else:
-        for lm_scale in lm_scale_list:
-            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
-    return ans
-
-
-def decode_dataset(
-    dl: torch.utils.data.DataLoader,
-    params: AttributeDict,
-    model: nn.Module,
-    HLG: Optional[k2.Fsa],
-    H: Optional[k2.Fsa],
-    bpe_model: Optional[spm.SentencePieceProcessor],
-    word_table: k2.SymbolTable,
-    sos_id: int,
-    eos_id: int,
-    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[str], List[str]]]]:
-    """Decode dataset.
-
-    Args:
-      dl:
-        PyTorch's dataloader containing the dataset to decode.
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The neural model.
-      HLG:
-        The decoding graph. Used only when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      bpe_model:
-        The BPE model. Used only when params.method is ctc-decoding.
-      word_table:
-        It is the word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.
-    Returns:
-      Return a dict, whose key may be "no-rescore" if no LM rescoring
-      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
-    """
-    results = []
-
-    num_cuts = 0
-
-    try:
-        num_batches = len(dl)
-    except TypeError:
-        num_batches = "?"
-
-    results = defaultdict(list)
-    for batch_idx, batch in enumerate(dl):
-        texts = batch["supervisions"]["text"]
-
-        hyps_dict = decode_one_batch(
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            bpe_model=bpe_model,
-            batch=batch,
-            word_table=word_table,
-            G=G,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-
-        for lm_scale, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
-
-            results[lm_scale].extend(this_batch)
-
-        num_cuts += len(batch["supervisions"]["text"])
-
-        if batch_idx % 100 == 0:
-            batch_str = f"{batch_idx}/{num_batches}"
-
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
-    return results
-
-
-def save_results(
-    params: AttributeDict,
-    test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
-):
-    if params.method == "attention-decoder":
-        # Set it to False since there are too many logs.
-        enable_log = False
-    else:
-        enable_log = True
-    test_set_wers = dict()
-    for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
-
-        # The following prints out WERs, per-word error statistics and aligned
-        # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=enable_log
-            )
-            test_set_wers[key] = wer
-
-        if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
-
-    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
-    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
-        note = ""
-    logging.info(s)
-
-
-@torch.no_grad()
-def main():
-    parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
-    args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)
-
-    params = get_params()
-    params.update(vars(args))
-
-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
-    logging.info("Decoding started")
-    logging.info(params)
-
-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-    assert num_classes == params.num_tokens
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
-    sos_id = graph_compiler.sos_id
-    eos_id = graph_compiler.eos_id
-
-    if params.method == "ctc-decoding":
-        HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(str(params.lang_dir / "bpe.model"))
-    else:
-        H = None
-        bpe_model = None
-        HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-        )
-        HLG = HLG.to(device)
-        assert HLG.requires_grad is False
-
-        if not hasattr(HLG, "lm_scores"):
-            HLG.lm_scores = HLG.scores.clone()
-
-    if params.method in (
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-    ):
-        if not (params.lm_dir / "G_4_gram.pt").is_file():
-            logging.info("Loading G_4_gram.fst.txt")
-            logging.warning("It may take 8 minutes.")
-            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-                first_word_disambig_id = lexicon.word_table["#0"]
-
-                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-                # G.aux_labels is not needed in later computations, so
-                # remove it here.
-                del G.aux_labels
-                # CAUTION: The following line is crucial.
-                # Arcs entering the back-off state have label equal to #0.
-                # We have to change it to 0 here.
-                G.labels[G.labels >= first_word_disambig_id] = 0
-                G = k2.Fsa.from_fsas([G]).to(device)
-                G = k2.arc_sort(G)
-                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
-        else:
-            logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
-
-        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
-            # Add epsilon self-loops to G as we will compose
-            # it with the whole lattice later
-            G = k2.add_epsilon_self_loops(G)
-            G = k2.arc_sort(G)
-            G = G.to(device)
-
-        # G.lm_scores is used to replace HLG.lm_scores during
-        # LM rescoring.
-        G.lm_scores = G.scores.clone()
-    else:
-        G = None
-
-    model = MaskedLmConformer(
-        num_classes=params.num_tokens,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        num_decoder_layers=params.num_decoder_layers,
-    )
-
-    if params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
-
-    if params.export:
-        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
-        return
-    print("TODO: Implement me!")
-    # [ ] Add an option to use conformer lm for rescoring
-    # [ ] Load conformer_lm only when that options is activated
-    # [ ] Load conformer model
-    return
-
-    model.to(device)
-    model.eval()
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
-
-    librispeech = LibriSpeechAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    test_sets = ["test-clean", "test-other"]
-    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            bpe_model=bpe_model,
-            word_table=lexicon.word_table,
-            G=G,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
-
-    logging.info("Done!")
-
-
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-if __name__ == "__main__":
-    main()
@@ -20,6 +20,11 @@ from typing import Dict, List, Optional, Union
 import k2
 import torch

+from icefall.lm.rescore import (
+    compute_alignment,
+    make_hyp_to_ref_map,
+    prepare_conformer_lm_inputs,
+)
 from icefall.utils import get_texts


@@ -224,6 +229,7 @@ class Nbest(object):
         else:
             word_seq = lattice.aux_labels.index(path)
             word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
+            word_seq = word_seq.remove_values_leq(0)

         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
@@ -889,3 +895,198 @@ def rescore_with_attention_decoder(
             key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
             ans[key] = best_path
     return ans
+
+
+def rescore_with_conformer_lm(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: torch.nn.Module,
+    masked_lm_model: torch.nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: Optional[torch.Tensor],
+    sos_id: int,
+    eos_id: int,
+    blank_id: int,
+    nbest_scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
+    masked_lm_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts `num_paths` paths from the given lattice and uses
+    an attention decoder to rescore them. The path with the highest score is
+    the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc].
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `(T, N, C)`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape `(N, T)`.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      nbest_scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        leads to more unique paths at the risk of missing the correct path.
+      ngram_lm_scale:
+        Optional. It specifies the scale for n-gram LM scores.
+      attention_scale:
+        Optional. It specifies the scale for attention decoder scores.
+      masked_lm_scale:
+        Optional. It specifies the scale for conformer_lm scores.
+    Returns:
+      A dict of FsaVec, whose key contains a string
+      ngram_lm_scale_attention_scale and the value is the
+      best decoding path for each utterance in the lattice.
+    """
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set.
+    # Also, nbest.fsa inherits the attributes from `lattice`.
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+    ngram_lm_scores = nbest.compute_lm_scores()
+
+    # The `tokens` attribute is set inside `compile_hlg.py`
+    assert hasattr(nbest.fsa, "tokens")
+    assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+    # the shape of memory is (T, N, C), so we use axis=1 here
+    expanded_memory = memory.index_select(1, path_to_utt_map)
+
+    if memory_key_padding_mask is not None:
+        # The shape of memory_key_padding_mask is (N, T), so we
+        # use axis=0 here.
+        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+            0, path_to_utt_map
+        )
+    else:
+        expanded_memory_key_padding_mask = None
+
+    # remove axis corresponding to states.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+    tokens = tokens.remove_values_leq(0)
+
+    alignment = compute_alignment(tokens, nbest.shape)
+    (
+        masked_src_symbols,
+        src_symbols,
+        tgt_symbols,
+        src_key_padding_mask,
+        tgt_weights,
+    ) = prepare_conformer_lm_inputs(
+        alignment,
+        bos_id=sos_id,
+        eos_id=eos_id,
+        blank_id=blank_id,
+        unmasked_weight=0.0,
+    )
+
+    masked_src_symbols = masked_src_symbols.to(torch.int64)
+    src_symbols = src_symbols.to(torch.int64)
+    tgt_symbols = tgt_symbols.to(torch.int64)
+
+    masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
+        masked_src_symbols, src_key_padding_mask
+    )
+
+    tgt_nll = masked_lm_model.decoder_nll(
+        masked_lm_memory,
+        masked_lm_pos_emb,
+        src_symbols,
+        tgt_symbols,
+        src_key_padding_mask,
+    )
+
+    # nll means negative log-likelihood
+    # ll means log-likelihood
+    tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
+
+    # Note: log-likelihood for those pairs that have identical src/tgt are 0
+    # since their tgt_weights is 0
+
+    # TODO(fangjun): Add documentation about why we do the following
+    tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
+    tgt_ll_shape = k2.ragged.create_ragged_shape2(
+        row_splits=None,
+        row_ids=tgt_ll_shape_row_ids,
+        cached_tot_size=tgt_ll_shape_row_ids.numel(),
+    )
+    ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
+
+    ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
+    masked_lm_scores = ragged_tgt_ll.max()
+
+    # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
+    token_ids = tokens.tolist()
+
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    if masked_lm_scale is None:
+        masked_lm_scale_list = [0.01, 0.05, 0.08]
+        masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        masked_lm_scale_list = [masked_lm_scale]
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            for m_scale in masked_lm_scale_list:
+                tot_scores = (
+                    am_scores.values
+                    + n_scale * ngram_lm_scores.values
+                    + a_scale * attention_scores
+                    + m_scale * masked_lm_scores
+                )
+                ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+                max_indexes = ragged_tot_scores.argmax()
+                best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+                key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}"  # noqa
+                ans[key] = best_path
+    return ans
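As a usage note for the new helper above: `rescore_with_conformer_lm()` returns a dict whose keys encode the scale triple, for example `ngram_lm_scale_0.1_attention_scale_0.3_masked_lm_scale_0.5`, and whose values are the best paths per utterance. Below is a hedged sketch of calling it and turning the result into word-level hypotheses, assuming `rescored_lattice`, `params`, `model`, `masked_lm_model`, `memory`, `memory_key_padding_mask`, `sos_id`, `eos_id`, `word_table`, and `get_texts` are prepared exactly as in `decode_one_batch()` of the decoding script earlier in this commit.

# Sketch only: relies on variables prepared as in decode_one_batch().
best_path_dict = rescore_with_conformer_lm(
    lattice=rescored_lattice,
    num_paths=params.num_paths,
    model=model,
    masked_lm_model=masked_lm_model,
    memory=memory,
    memory_key_padding_mask=memory_key_padding_mask,
    sos_id=sos_id,
    eos_id=eos_id,
    blank_id=0,
    nbest_scale=params.nbest_scale,
)

ans = dict()
for key, best_path in best_path_dict.items():
    hyps = get_texts(best_path)  # word IDs per utterance
    ans[key] = [[word_table[i] for i in ids] for ids in hyps]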
@@ -32,6 +32,8 @@ We can generate the following inputs for the conformer LM model from `tokens`:
 - src
 - tgt
 by using `k2.levenshtein_alignment`.
+
+TODO(fangjun): Add more doc about rescoring with masked conformer-lm.
 """

 from typing import Tuple
@@ -39,8 +41,6 @@ from typing import Tuple
 import k2
 import torch

-from icefall.decode import Nbest
-

 def make_key_padding_mask(lengths: torch.Tensor):
     """
@@ -236,7 +236,7 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
     >>> tokens
     [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
     >>> make_repeat(tokens)
-    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ]
+    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ]  # noqa

     TODO: Add documentation.

@@ -300,6 +300,7 @@ def prepare_conformer_lm_inputs(
       alignments:
         It is computed by :func:`compute_alignment`
     """
+    device = alignment.device
     # alignment.arcs.shape has axes [fsa][state][arc]
     # we remove axis 1, i.e., state, here
     labels_shape = alignment.arcs.shape().remove_axis(1)
@@ -337,6 +338,7 @@ def prepare_conformer_lm_inputs(
         (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
         fill_value=1,
         dtype=torch.float32,
+        device=device,
     )

     # find unmasked positions
@@ -359,52 +361,3 @@ def prepare_conformer_lm_inputs(
         src_key_padding_mask,
         weight,
     )
-
-
-def conformer_lm_rescore(
-    nbest: Nbest,
-    model: torch.nn.Module,
-    bos_id: int,
-    eos_id: int,
-    blank_id: int,
-    unmasked_weight: float = 0.25,
-    # TODO: add other arguments if needed
-) -> k2.RaggedTensor:
-    """Rescore an Nbest object with a conformer_lm model.
-
-    Args:
-      nbest:
-        It contains linear FSAs to be rescored.
-      model:
-        A conformer lm model. See "conformer_lm/train.py"
-
-    Returns:
-      Return a ragged tensor containing scores for each path
-      contained in the nbest. Its shape equals to `nbest.shape`.
-    """
-    assert hasattr(nbest.fsa, "tokens")
-    utt_path_shape = nbest.shape
-    # nbest.fsa.arcs.shape() has axes [path][state][arc]
-    # We remove the state axis here
-    path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)
-
-    path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
-    path_token = path_token.remove_values_leq(0)
-
-    alignment = compute_alignment(path_token, utt_path_shape)
-    (
-        masked_src,
-        src,
-        tgt,
-        src_key_padding_mask,
-        weight,
-    ) = prepare_conformer_lm_inputs(
-        alignment,
-        bos_id=bos_id,
-        eos_id=eos_id,
-        blank_id=blank_id,
-        unmasked_weight=unmasked_weight,
-    )
-    return masked_src, src, tgt, src_key_padding_mask, weight
-    # TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
-    # to the given model
@@ -17,12 +17,10 @@
 import k2
 import torch

-from icefall.decode import Nbest
 from icefall.lm.rescore import (
     add_bos,
     add_eos,
     compute_alignment,
-    conformer_lm_rescore,
     make_hyp_to_ref_map,
     make_repeat,
     make_repeat_map,
@@ -45,6 +43,7 @@ def test_add_eos():
     expected = k2.RaggedTensor(
         [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
     )
+    assert str(ragged_eos) == str(expected)


 def test_pad():
@@ -71,7 +70,7 @@ def test_make_hyp_to_ref_map():
     repeat_map = make_hyp_to_ref_map(row_splits)
     # fmt: off
     expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
-                             3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map)
+                             3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map)  # noqa
     # fmt: on
     assert torch.all(torch.eq(repeat_map, expected))

@@ -82,8 +81,8 @@ def test_make_repeat_map():
     repeat_map = make_repeat_map(row_splits)
     # fmt: off
     expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
-                             3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
-                             3, 4, 5, 6]).to(repeat_map)
+                             3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,  # noqa
+                             3, 4, 5, 6]).to(repeat_map)  # noqa
     # fmt: on
     assert torch.all(torch.eq(repeat_map, expected))

@@ -132,27 +131,6 @@ def test_compute_alignment():
     # print("weight", weight)


-def test_conformer_lm_rescore():
-    path00 = k2.linear_fsa([1, 2, 0, 3, 0, 5])
-    path01 = k2.linear_fsa([1, 0, 5, 0])
-    path10 = k2.linear_fsa([9, 8, 0, 3, 0, 2])
-    path11 = k2.linear_fsa([9, 8, 0, 0, 3, 2])
-    path12 = k2.linear_fsa([9, 0, 8, 4, 0, 2, 3])
-
-    fsa = k2.Fsa.from_fsas([path00, path01, path10, path11, path12])
-    fsa.tokens = fsa.labels.clone()
-    shape = k2.RaggedShape("[[x x] [x x x]]")
-    nbest = Nbest(fsa, shape)
-    masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
-        nbest, model=None, bos_id=10, eos_id=20, blank_id=0
-    )
-    print("masked src", masked_src)
-    print("src", src)
-    print("tgt", tgt)
-    print("src_key_padding_mask", src_key_padding_mask)
-    print("weight", weight)
-
-
 def main():
     test_add_bos()
     test_add_eos()
@@ -161,7 +139,6 @@ def main():
     test_make_hyp_to_ref_map()
     test_make_repeat()
     test_compute_alignment()
-    test_conformer_lm_rescore()


 if __name__ == "__main__":