Add decoding script.

Fangjun Kuang 2021-12-13 19:49:50 +08:00
parent 73ba843d0a
commit e38f04e70f
2 changed files with 122 additions and 431 deletions

File 1 of 2 (file path not preserved in this view):

@@ -428,8 +428,6 @@ def decode_dataset(
       The first is the reference transcript, and the second is the
       predicted result.
     """
-    results = []
-
     num_cuts = 0

     try:

File 2 of 2 (file path not preserved in this view):

@@ -20,31 +20,22 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple

-import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from transducer.beam_search import greedy_search
+from transducer.conformer import Conformer
+from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer

-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
-)
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
     write_error_stats,
@@ -72,76 +63,18 @@ def get_parser():
         "'--epoch'. ",
     )

-    parser.add_argument(
-        "--method",
-        type=str,
-        default="attention-decoder",
-        help="""Decoding method.
-        Supported values are:
-        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
-          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
-          It needs neither a lexicon nor an n-gram LM.
-        - (1) 1best. Extract the best path from the decoding lattice as the
-          decoding result.
-        - (2) nbest. Extract n paths from the decoding lattice; the path
-          with the highest score is the decoding result.
-        - (3) nbest-rescoring. Extract n paths from the decoding lattice,
-          rescore them with an n-gram LM (e.g., a 4-gram LM); the path with
-          the highest score is the decoding result.
-        - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
-          n-gram LM (e.g., a 4-gram LM); the best path of the rescored
-          lattice is the decoding result.
-        - (5) attention-decoder. Extract n paths from the LM-rescored
-          lattice; the path with the highest score is the decoding result.
-        - (6) nbest-oracle. Its WER is the lower bound that any n-best
-          rescoring method can achieve. Useful for debugging n-best
-          rescoring methods.
-        """,
-    )
-
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""Number of paths for n-best based decoding methods.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
-        """,
-    )
-
-    parser.add_argument(
-        "--nbest-scale",
-        type=float,
-        default=0.5,
-        help="""The scale to be applied to `lattice.scores`.
-        It is needed if you use any kind of n-best based rescoring.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
-        A smaller value results in more unique paths.
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="conformer_ctc/exp",
+        default="transducer/exp",
         help="The experiment dir",
     )

     parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_bpe_500",
-        help="The lang dir",
-    )
-
-    parser.add_argument(
-        "--lm-dir",
-        type=str,
-        default="data/lm",
-        help="""The LM dir.
-        It should contain either G_4_gram.pt or G_4_gram.fst.txt
-        """,
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
     )

     return parser
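
With these changes, a minimal invocation of the rewritten parser needs only the transducer-specific options. A hypothetical sketch; the argument values are illustrative, and LibriSpeechAsrDataModule adds further dataset options in main():

    # Sketch only; values are illustrative.
    args = get_parser().parse_args(
        [
            "--epoch", "20",
            "--avg", "10",
            "--exp-dir", "transducer/exp",
            "--bpe-model", "data/lang_bpe_500/bpe.model",
        ]
    )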
@@ -151,250 +84,138 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             # parameters for conformer
+            "feature_dim": 80,
+            "encoder_out_dim": 512,
             "subsampling_factor": 4,
+            "attention_dim": 512,
+            "nhead": 8,
+            "dim_feedforward": 2048,
+            "num_encoder_layers": 12,
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
-            "feature_dim": 80,
-            "nhead": 8,
-            "attention_dim": 512,
-            "num_decoder_layers": 6,
-            # parameters for decoding
-            "search_beam": 20,
-            "output_beam": 8,
-            "min_active_states": 30,
-            "max_active_states": 10000,
-            "use_double_scores": True,
+            # decoder params
+            "decoder_embedding_dim": 1024,
+            "num_decoder_layers": 4,
+            "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
     )
     return params
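
AttributeDict, imported from icefall.utils above, is a dict whose keys can also be read as attributes, which is why the rest of the script can write params.feature_dim instead of params["feature_dim"]. A minimal sketch of that behavior (not the icefall class itself):

    class AttrDictSketch(dict):
        """Dict with attribute-style access, mirroring icefall's AttributeDict."""
        def __getattr__(self, key):
            return self[key]

        def __setattr__(self, key, value):
            self[key] = value

    params = AttrDictSketch({"feature_dim": 80})
    assert params.feature_dim == params["feature_dim"] == 80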
+
+
+def get_encoder_model(params: AttributeDict):
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        output_dim=params.encoder_out_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.attention_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict):
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        embedding_dim=params.decoder_embedding_dim,
+        blank_id=params.blank_id,
+        sos_id=params.sos_id,
+        num_layers=params.num_decoder_layers,
+        hidden_dim=params.decoder_hidden_dim,
+        output_dim=params.encoder_out_dim,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict):
+    joiner = Joiner(
+        input_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict):
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+    )
+    return model
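
Taken together, the four new builders assemble the three RNN-T components (encoder, decoder, joiner) into a single Transducer. A hedged sketch of the expected call sequence; the vocab_size/blank_id/sos_id values below are placeholders, since main() fills them in from the BPE model later in this diff:

    params = get_params()
    params.vocab_size = 500  # placeholder; main() reads sp.get_piece_size()
    params.blank_id = 0      # placeholder; main() reads sp.piece_to_id("<blk>")
    params.sos_id = 1        # placeholder; main() reads sp.piece_to_id("<sos/eos>")
    model = get_transducer_model(params)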
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
-    H: Optional[k2.Fsa],
-    bpe_model: Optional[spm.SentencePieceProcessor],
+    sp: spm.SentencePieceProcessor,
     batch: dict,
-    word_table: k2.SymbolTable,
-    sos_id: int,
-    eos_id: int,
-    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

     - key: It indicates the setting used for decoding. For example,
-      if no rescoring is used, the key is the string `no_rescore`.
-      If LM rescoring is used, the key is the string `lm_scale_xxx`,
-      where `xxx` is the value of `lm_scale`. An example key is
-      `lm_scale_0.7`.
+      if greedy search is used, the key is "greedy_search".
+      If beam search with a beam size of 7 is used, the key is
+      "beam_7".
     - value: It contains the decoding result. `len(value)` equals the
       batch size. `value[i]` is the decoding result for the i-th
       utterance in the given batch.

     Args:
       params:
         It's the return value of :func:`get_params`.
-        - If params.method is "1best", it uses 1best decoding without LM rescoring.
-        - If params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - If params.method is "nbest-rescoring", it uses nbest LM rescoring.
-        - If params.method is "whole-lattice-rescoring", it uses whole-lattice
-          LM rescoring.
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used only when params.method is NOT ctc-decoding.
-      H:
-        The CTC topo. Used only when params.method is ctc-decoding.
-      bpe_model:
-        The BPE model. Used only when params.method is ctc-decoding.
+      sp:
+        The BPE model.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      word_table:
-        The word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.

     Returns:
       Return the decoding result. See above description for the format of
-      the returned dict. Note: If it decodes to nothing, then return None.
+      the returned dict.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = model.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
     # at entry, feature is (N, T, C)

     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)

-    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is (N, T, C)
-
-    supervision_segments = torch.stack(
-        (
-            supervisions["sequence_idx"],
-            supervisions["start_frame"] // params.subsampling_factor,
-            supervisions["num_frames"] // params.subsampling_factor,
-        ),
-        1,
-    ).to(torch.int32)
-
-    if H is None:
-        assert HLG is not None
-        decoding_graph = HLG
-    else:
-        assert HLG is None
-        assert bpe_model is not None
-        decoding_graph = H
-
-    lattice = get_lattice(
-        nnet_output=nnet_output,
-        decoding_graph=decoding_graph,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
-
-    if params.method == "ctc-decoding":
-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
-        # Note: `best_path.aux_labels` contains token IDs, not word IDs,
-        # since we are using H, not HLG here.
-        #
-        # token_ids is a list-of-lists of IDs
-        token_ids = get_texts(best_path)
-
-        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
-        hyps = bpe_model.decode(token_ids)
-
-        # hyps is a list of lists of str, e.g., [['xxx', 'yyy', 'zzz'], ...]
-        hyps = [s.split() for s in hyps]
-        key = "ctc-decoding"
-        return {key: hyps}
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
+
+    hyps = []
+    batch_size = encoder_out.size(0)
+
+    for i in range(batch_size):
+        # fmt: off
+        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+        # fmt: on
+        hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        hyps.append(sp.decode(hyp).split())
+
+    # TODO: Implement beam search
+    return {"greedy_search": hyps}
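
greedy_search itself lives in transducer/beam_search.py, which is outside this diff. For orientation, RNN-T greedy search conventionally walks the encoder frames, takes the arg-max of the joiner output at each frame, and advances the prediction network only on non-blank emissions. A generic sketch under those assumptions; the decoder/joiner signatures here are hypothetical, not icefall's:

    import torch

    def rnnt_greedy_sketch(enc, decoder, joiner, blank_id, sos_id):
        """Generic RNN-T greedy search for one utterance; enc is (T, C)."""
        y = torch.tensor([[sos_id]])
        dec_out, state = decoder(y, None)  # hypothetical signature
        hyp = []
        for t in range(enc.size(0)):
            # The joiner combines one encoder frame with the decoder output.
            logits = joiner(enc[t], dec_out.squeeze())  # hypothetical signature
            token = int(logits.argmax(dim=-1))
            if token != blank_id:
                hyp.append(token)
                y = torch.tensor([[token]])
                dec_out, state = decoder(y, state)  # advance on non-blank only
        return hyp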
-    if params.method == "nbest-oracle":
-        # Note: You can also pass rescored lattices to it.
-        # We choose the HLG decoded lattice for speed reasons
-        # as HLG decoding is faster and the oracle WER
-        # is only slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            nbest_scale=params.nbest_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
-        return {key: hyps}
-
-    if params.method in ["1best", "nbest"]:
-        if params.method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-                nbest_scale=params.nbest_scale,
-            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
-
-    assert params.method in [
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-    ]
-
-    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=lm_scale_list,
-        )
-    elif params.method == "attention-decoder":
-        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=None,
-        )
-        # TODO: pass `lattice` instead of `rescored_lattice` to
-        # `rescore_with_attention_decoder`
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            nbest_scale=params.nbest_scale,
-        )
-    else:
-        assert False, f"Unsupported decoding method: {params.method}"
-
-    ans = dict()
-    if best_path_dict is not None:
-        for lm_scale_str, best_path in best_path_dict.items():
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            ans[lm_scale_str] = hyps
-    else:
-        ans = None
-    return ans
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
-    H: Optional[k2.Fsa],
-    bpe_model: Optional[spm.SentencePieceProcessor],
-    word_table: k2.SymbolTable,
-    sos_id: int,
-    eos_id: int,
-    G: Optional[k2.Fsa] = None,
+    sp: spm.SentencePieceProcessor,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -405,31 +226,15 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used only when params.method is NOT ctc-decoding.
-      H:
-        The CTC topo. Used only when params.method is ctc-decoding.
-      bpe_model:
-        The BPE model. Used only when params.method is ctc-decoding.
-      word_table:
-        It is the word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.
+      sp:
+        The BPE model.

     Returns:
-      Return a dict, whose key may be "no-rescore" if no LM rescoring
-      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or "beam_7" if a beam size of 7 is used.
       Its value is a list of tuples. Each tuple contains two elements:
       The first is the reference transcript, and the second is the
       predicted result.
     """
-    results = []
     num_cuts = 0

     try:
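
Concretely, the value returned by decode_dataset now has this shape (the transcripts are illustrative):

    {
        "greedy_search": [
            (["THE", "CAT", "SAT"], ["THE", "CAT", "SAT"]),  # (reference, hypothesis)
            (["ON", "THE", "MAT"], ["ON", "A", "MAT"]),
        ]
    }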
@@ -444,37 +249,18 @@ def decode_dataset(
             hyps_dict = decode_one_batch(
                 params=params,
                 model=model,
-                HLG=HLG,
-                H=H,
-                bpe_model=bpe_model,
+                sp=sp,
                 batch=batch,
-                word_table=word_table,
-                G=G,
-                sos_id=sos_id,
-                eos_id=eos_id,
             )

-            if hyps_dict is not None:
-                for lm_scale, hyps in hyps_dict.items():
-                    this_batch = []
-                    assert len(hyps) == len(texts)
-                    for hyp_words, ref_text in zip(hyps, texts):
-                        ref_words = ref_text.split()
-                        this_batch.append((ref_words, hyp_words))
-
-                    results[lm_scale].extend(this_batch)
-            else:
-                assert (
-                    len(results) > 0
-                ), "It should not decode to empty in the first batch!"
-                this_batch = []
-                hyp_words = []
-                for ref_text in texts:
-                    ref_words = ref_text.split()
-                    this_batch.append((ref_words, hyp_words))
-
-                for lm_scale in results.keys():
-                    results[lm_scale].extend(this_batch)
+            for name, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[name].extend(this_batch)

             num_cuts += len(texts)
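
Note that results[name].extend(...) presumes results behaves like a defaultdict(list); the `from collections import defaultdict` import is kept at the top of the file, but the line constructing results falls outside this hunk, so the exact form is an assumption. The presumed pattern:

    from collections import defaultdict

    results = defaultdict(list)  # assumed; constructed outside this hunk
    results["greedy_search"].extend([(["REF", "WORDS"], ["HYP", "WORDS"])])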
@@ -492,16 +278,10 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
-    if params.method == "attention-decoder":
-        # Set it to False since there are too many logs.
-        enable_log = False
-    else:
-        enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
         store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
+        logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -509,14 +289,11 @@ def save_results(
         errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+                f, f"{test_set_name}-{key}", results, enable_log=True
            )
             test_set_wers[key] = wer

-        if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
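
The sorted() call at the end of this hunk ranks the decoding settings from best to worst WER before the summary is written; for example (the WER values are illustrative):

    test_set_wers = {"greedy_search": 6.8, "beam_7": 6.5}
    ranked = sorted(test_set_wers.items(), key=lambda x: x[1])
    assert ranked == [("beam_7", 6.5), ("greedy_search", 6.8)]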
@@ -539,113 +316,31 @@ def main():
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)
-    args.lm_dir = Path(args.lm_dir)

     params = get_params()
     params.update(vars(args))

-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    setup_logger(f"{params.exp_dir}/log-decode")
     logging.info("Decoding started")
-    logging.info(params)
-
-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank

     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)

-    logging.info(f"device: {device}")
-
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
-    sos_id = graph_compiler.sos_id
-    eos_id = graph_compiler.eos_id
-
-    if params.method == "ctc-decoding":
-        HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(str(params.lang_dir / "bpe.model"))
-    else:
-        H = None
-        bpe_model = None
-        HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
-        )
-        assert HLG.requires_grad is False
-
-        if not hasattr(HLG, "lm_scores"):
-            HLG.lm_scores = HLG.scores.clone()
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
"whole-lattice-rescoring",
"attention-decoder",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
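
Only the params.avg == 1 branch is visible in this hunk. Given the average_checkpoints import kept above, the averaging branch presumably looks roughly like this sketch (based on sibling icefall recipes, not on lines shown here):

    if params.avg > 1:
        start = params.epoch - params.avg + 1
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
        ]
        model.load_state_dict(average_checkpoints(filenames))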
@@ -661,6 +356,8 @@ def main():
     model.to(device)
     model.eval()
+    model.device = device

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -680,17 +377,13 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
-            HLG=HLG,
-            H=H,
-            bpe_model=bpe_model,
-            word_table=lexicon.word_table,
-            G=G,
-            sos_id=sos_id,
-            eos_id=eos_id,
+            sp=sp,
         )

         save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
         )

     logging.info("Done!")