mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Minor fixes to the decode.py
This commit is contained in:
parent
8a94476fd9
commit
57c1d762b6
@ -58,6 +58,7 @@ Usage:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -76,6 +77,8 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import ContextGraph, LmScorer, NgramLm
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -211,6 +214,26 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help="""
|
||||||
|
The bonus score of each token for the context biasing words/phrases.
|
||||||
|
Used only when --decoding_method is modified_beam_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path of the context biasing lists, one word/phrase each line
|
||||||
|
Used only when --decoding_method is modified_beam_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -222,6 +245,7 @@ def decode_one_batch(
|
|||||||
token_table: k2.SymbolTable,
|
token_table: k2.SymbolTable,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -285,6 +309,7 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hyp_tokens = []
|
hyp_tokens = []
|
||||||
@ -324,7 +349,12 @@ def decode_one_batch(
|
|||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
key = f"beam_size_{params.beam_size}"
|
||||||
|
if params.has_contexts:
|
||||||
|
key += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
key += "-no-context-words"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -333,6 +363,7 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
token_table: k2.SymbolTable,
|
token_table: k2.SymbolTable,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -377,6 +408,7 @@ def decode_dataset(
|
|||||||
model=model,
|
model=model,
|
||||||
token_table=token_table,
|
token_table=token_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -407,16 +439,17 @@ def save_results(
|
|||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||||
results = sorted(results)
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
# we compute CER for aishell dataset.
|
||||||
|
results_char = []
|
||||||
|
for res in results:
|
||||||
|
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||||
|
|
||||||
|
store_transcripts(filename=recog_path, texts=results_char)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||||
# we compute CER for aishell dataset.
|
|
||||||
results_char = []
|
|
||||||
for res in results:
|
|
||||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
wer = write_error_stats(
|
wer = write_error_stats(
|
||||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||||
@ -457,6 +490,12 @@ def main():
|
|||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
params.has_contexts = True
|
||||||
|
else:
|
||||||
|
params.has_contexts = False
|
||||||
|
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -470,6 +509,10 @@ def main():
|
|||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
if params.has_contexts:
|
||||||
|
params.suffix += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
params.suffix += "-no-contexts-words"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -490,6 +533,11 @@ def main():
|
|||||||
params.blank_id = 0
|
params.blank_id = 0
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -586,6 +634,19 @@ def main():
|
|||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
|
||||||
|
if params.decoding_method == "modified_beam_search":
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
contexts_text = []
|
||||||
|
for line in open(params.context_file).readlines():
|
||||||
|
contexts_text.append(line.strip())
|
||||||
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
|
context_graph = ContextGraph(params.context_score)
|
||||||
|
context_graph.build(contexts)
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -608,6 +669,7 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
token_table=lexicon.token_table,
|
token_table=lexicon.token_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
|||||||
@ -15,7 +15,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|||||||
@ -131,8 +131,6 @@ from pathlib import Path
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifst
|
|
||||||
import graphviz
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -576,7 +574,10 @@ def decode_one_batch(
|
|||||||
return {key: (hyps, timestamps)}
|
return {key: (hyps, timestamps)}
|
||||||
else:
|
else:
|
||||||
key = f"beam_size_{params.beam_size}"
|
key = f"beam_size_{params.beam_size}"
|
||||||
key += f"-context-score-{params.context_score}"
|
if params.has_contexts:
|
||||||
|
key += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
key += "-no-context-words"
|
||||||
return {key: (hyps, timestamps)}
|
return {key: (hyps, timestamps)}
|
||||||
|
|
||||||
|
|
||||||
@ -626,7 +627,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 50
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 1
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -759,6 +760,12 @@ def main():
|
|||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
params.has_contexts = True
|
||||||
|
else:
|
||||||
|
params.has_contexts = False
|
||||||
|
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -781,7 +788,10 @@ def main():
|
|||||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
params.suffix += f"-context-score-{params.context_score}"
|
if params.has_contexts:
|
||||||
|
params.suffix += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
params.suffix += "-no-context-words"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -938,14 +948,8 @@ def main():
|
|||||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
test_book_cuts = librispeech.test_book_cuts()
|
test_sets = ["test-clean", "test-other"]
|
||||||
test_book_dl = librispeech.test_dataloaders(test_book_cuts)
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
test_book2_cuts = librispeech.test_book2_cuts()
|
|
||||||
test_book2_dl = librispeech.test_dataloaders(test_book2_cuts)
|
|
||||||
|
|
||||||
test_sets = ["test-book", "test-book2", "test-clean", "test-other"]
|
|
||||||
test_dl = [test_book_dl, test_book2_dl, test_clean_dl, test_other_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
|||||||
@ -389,7 +389,6 @@ class LibriSpeechAsrDataModule:
|
|||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
cuts,
|
cuts,
|
||||||
num_buckets=2,
|
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
@ -468,25 +467,6 @@ class LibriSpeechAsrDataModule:
|
|||||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_book_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test-books cuts")
|
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_book_test_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test-books cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "libri_book_test_feats.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_book2_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test-books2 cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "libri_books2_feats.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_other_cuts(self) -> CutSet:
|
def test_other_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test-other cuts")
|
logging.info("About to get test-other cuts")
|
||||||
|
|||||||
@ -396,21 +396,14 @@ class WenetSpeechAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def valid_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get dev cuts")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV2.jsonl.gz")
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_net_cuts(self) -> List[CutSet]:
|
def test_net_cuts(self) -> List[CutSet]:
|
||||||
logging.info("About to get TEST_NET cuts")
|
logging.info("About to get TEST_NET cuts")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET2.jsonl.gz")
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_meeting_cuts(self) -> List[CutSet]:
|
def test_meeting_cuts(self) -> List[CutSet]:
|
||||||
logging.info("About to get TEST_MEETING cuts")
|
logging.info("About to get TEST_MEETING cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
|
||||||
self.args.manifest_dir / "cuts_TEST_MEETING2.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_car_cuts(self) -> List[CutSet]:
|
|
||||||
logging.info("About to get TEST_CAR cuts")
|
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "car_test_feats.jsonl.gz")
|
|
||||||
|
|||||||
@ -533,9 +533,12 @@ def decode_one_batch(
|
|||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
key = f"beam_size_{params.beam_size}"
|
||||||
f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
|
if params.has_contexts:
|
||||||
}
|
key += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
key += "-no-context-words"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -674,6 +677,12 @@ def main():
|
|||||||
"modified_beam_search_lm_shallow_fusion",
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
"modified_beam_search_LODR",
|
"modified_beam_search_LODR",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
params.has_contexts = True
|
||||||
|
else:
|
||||||
|
params.has_contexts = False
|
||||||
|
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
@ -683,7 +692,10 @@ def main():
|
|||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
params.suffix += f"-context-score-{params.context_score}"
|
if params.has_contexts:
|
||||||
|
params.suffix += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
params.suffix += "-no-contexts-words"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -851,14 +863,10 @@ def main():
|
|||||||
|
|
||||||
if params.decoding_method == "modified_beam_search":
|
if params.decoding_method == "modified_beam_search":
|
||||||
if os.path.exists(params.context_file):
|
if os.path.exists(params.context_file):
|
||||||
contexts = []
|
contexts_text = []
|
||||||
for line in open(params.context_file).readlines():
|
for line in open(params.context_file).readlines():
|
||||||
context_list = graph_compiler.texts_to_ids(line.strip())
|
contexts_text.append(line.strip())
|
||||||
tmp = []
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
for context in context_list:
|
|
||||||
for x in context:
|
|
||||||
tmp.append(x)
|
|
||||||
contexts.append(tmp)
|
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(contexts)
|
context_graph.build(contexts)
|
||||||
else:
|
else:
|
||||||
@ -882,11 +890,8 @@ def main():
|
|||||||
test_meeting_cuts = wenetspeech.test_meeting_cuts()
|
test_meeting_cuts = wenetspeech.test_meeting_cuts()
|
||||||
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
||||||
|
|
||||||
test_car_cuts = wenetspeech.test_car_cuts()
|
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
|
||||||
test_car_dl = wenetspeech.test_dataloaders(test_car_cuts)
|
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
|
||||||
|
|
||||||
test_sets = ["CAR", "TEST_NET", "DEV", "TEST_MEETING"]
|
|
||||||
test_dls = [test_car_dl, test_net_dl, dev_dl, test_meeting_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dls):
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user