mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
fix style
This commit is contained in:
parent
90dac69bc5
commit
e32bda6a7b
@ -45,33 +45,41 @@ import math
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Callable
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import k2
|
import k2
|
||||||
from lhotse import load_manifest_lazy
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import BertTokenizer, BertModel
|
|
||||||
from asr_datamodule import LibriHeavyAsrDataModule
|
from asr_datamodule import LibriHeavyAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_with_context,
|
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
greedy_search_batch_with_context,
|
greedy_search_batch_with_context,
|
||||||
|
greedy_search_with_context,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||||
from utils import get_facebook_biasing_list
|
from lhotse import load_manifest_lazy
|
||||||
from text_normalization import train_text_normalization, ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
|
from text_normalization import (
|
||||||
|
lower_all_char,
|
||||||
|
lower_only_alpha,
|
||||||
|
ref_text_normalization,
|
||||||
|
remove_non_alphabetic,
|
||||||
|
train_text_normalization,
|
||||||
|
upper_all_char,
|
||||||
|
upper_only_alpha,
|
||||||
|
)
|
||||||
from train_bert_encoder_with_style import (
|
from train_bert_encoder_with_style import (
|
||||||
|
_encode_texts_as_bytes_with_tokenizer,
|
||||||
add_model_arguments,
|
add_model_arguments,
|
||||||
get_params,
|
get_params,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
get_transducer_model,
|
get_transducer_model,
|
||||||
_encode_texts_as_bytes_with_tokenizer,
|
|
||||||
)
|
)
|
||||||
|
from transformers import BertModel, BertTokenizer
|
||||||
|
from utils import get_facebook_biasing_list
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -89,11 +97,13 @@ from icefall.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--world-size",
|
"--world-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -144,7 +154,7 @@ def get_parser():
|
|||||||
default="pruned_transducer_stateless7/exp",
|
default="pruned_transducer_stateless7/exp",
|
||||||
help="The experiment dir",
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-dir",
|
"--log-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -260,21 +270,20 @@ def get_parser():
|
|||||||
Used only when the decoding method is fast_beam_search_nbest,
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input-manifest",
|
"--input-manifest",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The input manifest to be decoded"
|
help="The input manifest to be decoded",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-manifest",
|
"--output-manifest",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Where to store the output manifest (directory)"
|
help="Where to store the output manifest (directory)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-pre-text",
|
"--use-pre-text",
|
||||||
@ -282,19 +291,19 @@ def get_parser():
|
|||||||
default=True,
|
default=True,
|
||||||
help="Use pre-text is available during decoding",
|
help="Use pre-text is available during decoding",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-style-prompt",
|
"--use-style-prompt",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="Use style prompt when evaluation"
|
help="Use style prompt when evaluation",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-context-embedding",
|
"--use-context-embedding",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Use context fuser when evaluation"
|
help="Use context fuser when evaluation",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -310,43 +319,44 @@ def get_parser():
|
|||||||
default=True,
|
default=True,
|
||||||
help="Reports CER. By default, only reports WER",
|
help="Reports CER. By default, only reports WER",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--style-text-transform",
|
"--style-text-transform",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||||
default="mixed-punc",
|
default="mixed-punc",
|
||||||
help="The style of style prompt, i.e style_text"
|
help="The style of style prompt, i.e style_text",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pre-text-transform",
|
"--pre-text-transform",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||||
default="mixed-punc",
|
default="mixed-punc",
|
||||||
help="The style of content prompt, i.e pre_text"
|
help="The style of content prompt, i.e pre_text",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-ls-test-set",
|
"--use-ls-test-set",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Use librispeech test set for evaluation."
|
help="Use librispeech test set for evaluation.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-ls-context-list",
|
"--use-ls-context-list",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="If use a fixed context list for LibriSpeech decoding"
|
help="If use a fixed context list for LibriSpeech decoding",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||||
"""Apply transform to a list of text. By default, the text are in
|
"""Apply transform to a list of text. By default, the text are in
|
||||||
ground truth format, i.e mixed-punc.
|
ground truth format, i.e mixed-punc.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -366,7 +376,7 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
|||||||
return [lower_all_char(s) for s in text]
|
return [lower_all_char(s) for s in text]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
@ -421,37 +431,43 @@ def decode_one_batch(
|
|||||||
cuts = batch["supervisions"]["cut"]
|
cuts = batch["supervisions"]["cut"]
|
||||||
cut_ids = [c.supervisions[0].id for c in cuts]
|
cut_ids = [c.supervisions[0].id for c in cuts]
|
||||||
batch_size = feature.size(0)
|
batch_size = feature.size(0)
|
||||||
|
|
||||||
# get pre_text
|
# get pre_text
|
||||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||||
pre_texts = batch["supervisions"]["text"] # use the ground truth ref text as pre_text
|
pre_texts = batch["supervisions"][
|
||||||
|
"text"
|
||||||
|
] # use the ground truth ref text as pre_text
|
||||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||||
else:
|
else:
|
||||||
pre_texts = ["" for _ in range(batch_size)]
|
pre_texts = ["" for _ in range(batch_size)]
|
||||||
|
|
||||||
if params.use_ls_context_list:
|
if params.use_ls_context_list:
|
||||||
pre_texts = [biasing_dict[id] for id in cut_ids]
|
pre_texts = [biasing_dict[id] for id in cut_ids]
|
||||||
|
|
||||||
# get style_text
|
# get style_text
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
||||||
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
|
style_texts = batch["supervisions"].get(
|
||||||
|
"style_text", [fixed_sentence for _ in range(batch_size)]
|
||||||
|
)
|
||||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||||
else:
|
else:
|
||||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||||
|
|
||||||
# Get the text embedding input
|
# Get the text embedding input
|
||||||
if params.use_pre_text or params.use_style_prompt:
|
if params.use_pre_text or params.use_style_prompt:
|
||||||
|
|
||||||
# apply style transform to the pre_text and style_text
|
# apply style transform to the pre_text and style_text
|
||||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||||
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
# pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
style_texts = _apply_style_transform(
|
||||||
|
style_texts, params.style_text_transform
|
||||||
|
)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
# Use tokenizer to prepare input for text encoder
|
# Use tokenizer to prepare input for text encoder
|
||||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||||
pre_texts=pre_texts,
|
pre_texts=pre_texts,
|
||||||
@ -459,11 +475,11 @@ def decode_one_batch(
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
memory, memory_key_padding_mask = model.encode_text(
|
memory, memory_key_padding_mask = model.encode_text(
|
||||||
encoded_inputs=encoded_inputs,
|
encoded_inputs=encoded_inputs,
|
||||||
style_lens=style_lens,
|
style_lens=style_lens,
|
||||||
) # (T,B,C)
|
) # (T,B,C)
|
||||||
else:
|
else:
|
||||||
memory = None
|
memory = None
|
||||||
memory_key_padding_mask = None
|
memory_key_padding_mask = None
|
||||||
@ -487,10 +503,7 @@ def decode_one_batch(
|
|||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if (
|
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
params.decoding_method == "greedy_search"
|
|
||||||
and params.max_sym_per_frame == 1
|
|
||||||
):
|
|
||||||
if memory is None or not params.use_context_embedding:
|
if memory is None or not params.use_context_embedding:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
@ -498,9 +511,11 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
|
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
|
||||||
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
|
context = model.context_fuser(
|
||||||
context = model.joiner.context_proj(context) # (N,C)
|
memory, padding_mask=memory_key_padding_mask
|
||||||
|
) # (N,C)
|
||||||
|
context = model.joiner.context_proj(context) # (N,C)
|
||||||
hyp_tokens = greedy_search_batch_with_context(
|
hyp_tokens = greedy_search_batch_with_context(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -533,19 +548,13 @@ def decode_one_batch(
|
|||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cur_context = context[i:i+1, :]
|
cur_context = context[i : i + 1, :]
|
||||||
hyp = greedy_search_with_context(
|
hyp = greedy_search_with_context(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
context=cur_context,
|
context=cur_context,
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
|
||||||
hyp = beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
@ -608,13 +617,15 @@ def decode_dataset(
|
|||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
texts = batch["supervisions"][
|
||||||
|
"text"
|
||||||
|
] # By default, this should be in mixed-punc format
|
||||||
|
|
||||||
# the style of ref_text should match style_text
|
# the style of ref_text should match style_text
|
||||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||||
|
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
@ -645,9 +656,7 @@ def decode_dataset(
|
|||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
batch_str = f"{batch_idx}/{num_batches}"
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
logging.info(
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -677,7 +686,9 @@ def save_results(
|
|||||||
|
|
||||||
if params.compute_CER:
|
if params.compute_CER:
|
||||||
# Write CER statistics
|
# Write CER statistics
|
||||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||||
|
)
|
||||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||||
errs_filename = (
|
errs_filename = (
|
||||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||||
@ -695,9 +706,7 @@ def save_results(
|
|||||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
|
||||||
)
|
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER", file=f)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
@ -712,9 +721,7 @@ def save_results(
|
|||||||
|
|
||||||
if params.compute_CER:
|
if params.compute_CER:
|
||||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
|
||||||
)
|
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tCER", file=f)
|
print("settings\tCER", file=f)
|
||||||
for key, val in test_set_cers:
|
for key, val in test_set_cers:
|
||||||
@ -740,65 +747,69 @@ def add_decoding_result_to_manifest(
|
|||||||
for items in value:
|
for items in value:
|
||||||
id, ref, hyp = items
|
id, ref, hyp = items
|
||||||
new_ans[id] = " ".join(hyp)
|
new_ans[id] = " ".join(hyp)
|
||||||
|
|
||||||
def _add_decoding(c):
|
def _add_decoding(c):
|
||||||
key = c.supervisions[0].id
|
key = c.supervisions[0].id
|
||||||
c.supervisions[0].texts.append(new_ans[key])
|
c.supervisions[0].texts.append(new_ans[key])
|
||||||
return c
|
return c
|
||||||
|
|
||||||
in_manifest = in_manifest.map(_add_decoding)
|
in_manifest = in_manifest.map(_add_decoding)
|
||||||
logging.info(f"Saving manifest to {out_manifest}")
|
logging.info(f"Saving manifest to {out_manifest}")
|
||||||
in_manifest.to_file(out_manifest)
|
in_manifest.to_file(out_manifest)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriHeavyAsrDataModule.add_arguments(parser)
|
LibriHeavyAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
cuts = load_manifest_lazy(args.input_manifest)
|
cuts = load_manifest_lazy(args.input_manifest)
|
||||||
|
|
||||||
world_size = args.world_size
|
world_size = args.world_size
|
||||||
assert world_size >= 1
|
assert world_size >= 1
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
splitted_cuts = cuts.split(num_splits=world_size)
|
splitted_cuts = cuts.split(num_splits=world_size)
|
||||||
mp.spawn(run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True)
|
mp.spawn(
|
||||||
|
run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
run(rank=0, world_size=1, args=args, cuts=cuts)
|
run(rank=0, world_size=1, args=args, cuts=cuts)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def run(rank, world_size, args, cuts):
|
def run(rank, world_size, args, cuts):
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.use_pre_text:
|
if params.use_pre_text:
|
||||||
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
||||||
|
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||||
|
|
||||||
params.suffix += f"-{rank}"
|
params.suffix += f"-{rank}"
|
||||||
|
|
||||||
world_size = params.world_size
|
world_size = params.world_size
|
||||||
|
|
||||||
params.output_manifest = Path(params.output_manifest)
|
params.output_manifest = Path(params.output_manifest)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
cuts = cuts[rank]
|
cuts = cuts[rank]
|
||||||
out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
|
out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
|
||||||
else:
|
else:
|
||||||
out_name = params.output_manifest / f"with_decoding.jsonl.gz"
|
out_name = params.output_manifest / "with_decoding.jsonl.gz"
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
|
setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
@ -819,9 +830,9 @@ def run(rank, world_size, args, cuts):
|
|||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg
|
||||||
)[: params.avg]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -848,9 +859,9 @@ def run(rank, world_size, args, cuts):
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg + 1
|
||||||
)[: params.avg + 1]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -909,14 +920,16 @@ def run(rank, world_size, args, cuts):
|
|||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
libriheavy = LibriHeavyAsrDataModule(args)
|
libriheavy = LibriHeavyAsrDataModule(args)
|
||||||
|
|
||||||
dl = libriheavy.valid_dataloaders(cuts, text_sampling_func=naive_triplet_text_sampling)
|
dl = libriheavy.valid_dataloaders(
|
||||||
|
cuts, text_sampling_func=naive_triplet_text_sampling
|
||||||
|
)
|
||||||
|
|
||||||
test_sets = ["test"]
|
test_sets = ["test"]
|
||||||
test_dl = [dl]
|
test_dl = [dl]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
biasing_dict = None
|
biasing_dict = None
|
||||||
|
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
@ -933,7 +946,7 @@ def run(rank, world_size, args, cuts):
|
|||||||
# test_set_name=test_set,
|
# test_set_name=test_set,
|
||||||
# results_dict=results_dict,
|
# results_dict=results_dict,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
add_decoding_result_to_manifest(
|
add_decoding_result_to_manifest(
|
||||||
in_manifest=cuts,
|
in_manifest=cuts,
|
||||||
out_manifest=out_name,
|
out_manifest=out_name,
|
||||||
@ -942,6 +955,7 @@ def run(rank, world_size, args, cuts):
|
|||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
# torch.set_num_threads(1)
|
# torch.set_num_threads(1)
|
||||||
# torch.set_num_interop_threads(1)
|
# torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user