This commit is contained in:
marcoyang1998 2023-09-19 18:38:56 +08:00
parent bea1bd295f
commit 6579800720
3 changed files with 368 additions and 357 deletions

View File

@@ -72,16 +72,12 @@ class LibriHeavyAsrDataModule:
         self.args = args

         if args.use_context_list:
-            from dataset2 import PromptASRDataset
-
             assert args.rare_word_file is not None
             with open(args.rare_word_file, "r") as f:
                 self.rare_word_list = (
                     f.read().lower().split()
                 )  # Use lower-cased for easier style transform
         else:
-            from dataset import PromptASRDataset
-
            self.rare_word_list = None

     @classmethod
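
For reference, `f.read().lower().split()` treats the rare-word file as plain whitespace-separated tokens, lower-cased on load. A minimal, self-contained illustration (file name and contents are made up):

    # Illustrative only: write a rare-word file, then load it as the branch above does.
    with open("rare_words.txt", "w") as f:
        f.write("Hobgoblin QUINCE marzipan\nsaucepan\n")

    with open("rare_words.txt", "r") as f:
        rare_word_list = f.read().lower().split()

    print(rare_word_list)  # ['hobgoblin', 'quince', 'marzipan', 'saucepan']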

View File

@@ -20,22 +20,55 @@
 """
 Usage:

 (1) greedy search
-./pruned_transducer_stateless7/decode.py \
+./zipformer_prompt_asr/decode_bert.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
-    --max-duration 600 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --max-duration 1000 \
     --decoding-method greedy_search

 (2) modified beam search
-./pruned_transducer_stateless7/decode.py \
+./zipformer_prompt_asr/decode_bert.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
-    --max-duration 600 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --max-duration 1000 \
     --decoding-method modified_beam_search \
     --beam-size 4
+
+(3) Decode LibriSpeech
+./zipformer_prompt_asr/decode_bert.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --max-duration 1000 \
+    --decoding-method modified_beam_search \
+    --use-ls-test-set True \
+    --beam-size 4
+
+(4) Decode LibriSpeech + biasing list
+biasing_list=100  # could also be 0
+./zipformer_prompt_asr/decode_bert.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --max-duration 1000 \
+    --decoding-method modified_beam_search \
+    --beam-size 4 \
+    --use-ls-test-set True \
+    --use-ls-context-list True \
+    --biasing-level utterance \
+    --ls-distractors $biasing_list \
+    --post-normalization True \
+    --text-encoder-type BERT \
+    --max-prompt-lens 1000 \
+    --style-text-transform mixed-punc \
+    --pre-text-transform mixed-punc
 """
@@ -45,40 +78,34 @@ import math
 import warnings
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Callable
+from typing import Callable, Dict, List, Optional, Tuple

 import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from transformers import BertTokenizer, BertModel
 from asr_datamodule import LibriHeavyAsrDataModule
-from beam_search import (
-    greedy_search,
-    greedy_search_with_context,
-    greedy_search_batch,
-    greedy_search_batch_with_context,
-    modified_beam_search,
-)
+from beam_search import greedy_search, greedy_search_batch, modified_beam_search
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
-from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
 from ls_text_normalization import word_normalization
 from text_normalization import (
-    ref_text_normalization,
-    remove_non_alphabetic,
-    upper_only_alpha,
-    upper_all_char,
     lower_all_char,
     lower_only_alpha,
+    ref_text_normalization,
+    remove_non_alphabetic,
     train_text_normalization,
+    upper_all_char,
+    upper_only_alpha,
 )
-from train_bert_encoder_with_style import (
+from train_bert_encoder import (
+    _encode_texts_as_bytes_with_tokenizer,
     add_model_arguments,
     get_params,
     get_tokenizer,
     get_transducer_model,
-    _encode_texts_as_bytes_with_tokenizer,
 )
+from transformers import BertModel, BertTokenizer
+from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats

 from icefall.checkpoint import (
     average_checkpoints,
@@ -87,15 +114,11 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-)
+from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool

 LOG_EPS = math.log(1e-10)


 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -267,7 +290,7 @@ def get_parser():
         "--use-style-prompt",
         type=str2bool,
         default=True,
-        help="Use style prompt when evaluation"
+        help="Use style prompt when evaluation",
     )

     parser.add_argument(
@@ -276,13 +299,6 @@
         default=1000,
     )

-    parser.add_argument(
-        "--use-context-embedding",
-        type=str2bool,
-        default=False,
-        help="Use context fuser when evaluation"
-    )
-
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
@@ -290,16 +306,10 @@
         help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
     )

-    parser.add_argument(
-        "--long-audio-recog",
-        type=str2bool,
-        default=False,
-    )
-
     parser.add_argument(
         "--compute-CER",
         type=str2bool,
-        default=True,
+        default=False,
         help="Reports CER. By default, only reports WER",
     )
@@ -308,7 +318,7 @@ def get_parser():
         type=str,
         choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of style prompt, i.e style_text"
+        help="The style of style prompt, i.e style_text",
     )

     parser.add_argument(
@@ -316,21 +326,21 @@
         type=str,
         choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of content prompt, i.e pre_text"
+        help="The style of content prompt, i.e pre_text",
     )

     parser.add_argument(
         "--use-ls-test-set",
         type=str2bool,
         default=False,
-        help="Use librispeech test set for evaluation."
+        help="Use librispeech test set for evaluation.",
     )

     parser.add_argument(
         "--use-ls-context-list",
         type=str2bool,
         default=False,
-        help="If use a fixed context list for LibriSpeech decoding"
+        help="If use a fixed context list for LibriSpeech decoding",
     )

     parser.add_argument(
@@ -344,13 +354,14 @@ def get_parser():
         "--ls-distractors",
         type=int,
         default=0,
-        help="The number of distractors into context list for LibriSpeech decoding"
+        help="The number of distractors into context list for LibriSpeech decoding",
     )

     add_model_arguments(parser)

     return parser


 def _apply_style_transform(text: List[str], transform: str) -> List[str]:
     """Apply transform to a list of text. By default, the text are in
     ground truth format, i.e mixed-punc.
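
The body of `_apply_style_transform` falls outside this hunk; for orientation, here is a hedged sketch of the mapping implied by the four `choices` above and the helpers imported from `text_normalization` (the exact body in the file may differ):

    from typing import List

    from text_normalization import lower_all_char, lower_only_alpha, upper_only_alpha

    def apply_style_transform_sketch(text: List[str], transform: str) -> List[str]:
        # Input is assumed to be in ground-truth format, i.e. mixed-punc.
        if transform == "mixed-punc":
            return text  # keep casing and punctuation
        elif transform == "upper-no-punc":
            return [upper_only_alpha(t) for t in text]  # UPPER CASE, letters only
        elif transform == "lower-no-punc":
            return [lower_only_alpha(t) for t in text]  # lower case, letters only
        elif transform == "lower-punc":
            return [lower_all_char(t) for t in text]  # lower case, keep punctuation
        else:
            raise NotImplementedError(f"Unseen transform: {transform}")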
@@ -378,7 +389,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    tokenizer,
+    tokenizer: spm.SentencePieceProcessor,
     batch: dict,
     biasing_dict: dict = None,
     word_table: Optional[k2.SymbolTable] = None,
@@ -401,10 +412,15 @@ def decode_one_batch(
       The neural model.
     sp:
       The BPE model.
+    tokenizer:
+      Tokenizer for the text encoder
     batch:
       It is the return value from iterating
       `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
       for the format of the `batch`.
+    biasing_dict:
+      A dictionary in the form `{cut_id: "w1 w2"}` that contains a list
+      of biasing words (separated with space)
     word_table:
       The word symbol table.
     decoding_graph:
@@ -428,43 +444,48 @@ def decode_one_batch(
     cut_ids = [c.supervisions[0].id for c in cuts]
     batch_size = feature.size(0)

+    # get pre_text
     if "pre_text" in batch["supervisions"] and params.use_pre_text:
         pre_texts = batch["supervisions"]["pre_text"]
         pre_texts = [train_text_normalization(t) for t in pre_texts]
     else:
         pre_texts = ["" for _ in range(batch_size)]

-    if params.use_ls_context_list:
+    if params.use_ls_context_list and params.use_ls_test_set:
         if params.biasing_level == "utterance":
             pre_texts = [biasing_dict[id] for id in cut_ids]
         elif params.biasing_level == "Chapter":
-            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            chapter_ids = [c.split("-")[1] for c in cut_ids]
             pre_texts = [biasing_dict[id] for id in chapter_ids]
         elif params.biasing_level == "Book":
-            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            chapter_ids = [c.split("-")[1] for c in cut_ids]
             pre_texts = [biasing_dict[id] for id in chapter_ids]
+        else:
+            raise ValueError(f"Unseen biasing level: {params.biasing_level}")

         if params.pre_text_transform == "mixed-punc":
             pre_texts = [t.lower() for t in pre_texts]

     # get style_text
     if params.use_style_prompt:
         fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
-        style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
+        style_texts = batch["supervisions"].get(
+            "style_text", [fixed_sentence for _ in range(batch_size)]
+        )
         style_texts = [train_text_normalization(t) for t in style_texts]
     else:
         style_texts = ["" for _ in range(batch_size)]  # use empty string

-    # Get the text embedding input
+    # Get the text embedding
     if params.use_pre_text or params.use_style_prompt:
         # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
         if not params.use_ls_context_list:
-            pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
-            #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
+            pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts]

         if params.use_style_prompt:
-            style_texts = _apply_style_transform(style_texts, params.style_text_transform)
+            style_texts = _apply_style_transform(
+                style_texts, params.style_text_transform
+            )

     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
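
The truncation change above is behavioral, not just stylistic: `t[:n]` kept the first `max_prompt_lens` characters of the pre-text, while `t[-n:]` keeps the last ones, i.e. the history closest to the current utterance. A quick illustration (lengths are made up):

    pre_text = "a" * 1200  # illustrative 1200-character history
    max_prompt_lens = 1000

    head = pre_text[: max_prompt_lens]   # old behavior: keeps the oldest context
    tail = pre_text[-max_prompt_lens :]  # new behavior: keeps the most recent context
    assert len(head) == len(tail) == 1000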
@@ -477,7 +498,9 @@ def decode_one_batch(
             device=device,
             no_limit=True,
         )
-        logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")
+        logging.info(
+            f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}"
+        )

         memory, memory_key_padding_mask = model.encode_text(
             encoded_inputs=encoded_inputs,
@@ -506,26 +529,12 @@ def decode_one_batch(
     hyps = []

-    if (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
-        if memory is None or not params.use_context_embedding:
-            hyp_tokens = greedy_search_batch(
-                model=model,
-                encoder_out=encoder_out,
-                encoder_out_lens=encoder_out_lens,
-            )
-        else:
-            memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
-            context = model.context_fuser(memory, padding_mask=memory_key_padding_mask)  # (N,C)
-            context = model.joiner.context_proj(context)  # (N,C)
-            hyp_tokens = greedy_search_batch_with_context(
-                model=model,
-                encoder_out=encoder_out,
-                encoder_out_lens=encoder_out_lens,
-                context=context,
-            )
+    if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
@@ -545,26 +554,11 @@ def decode_one_batch(
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                if memory is None or not params.use_context_embedding:
-                    hyp = greedy_search(
-                        model=model,
-                        encoder_out=encoder_out_i,
-                        max_sym_per_frame=params.max_sym_per_frame,
-                    )
-                else:
-                    cur_context = context[i:i+1, :]
-                    hyp = greedy_search_with_context(
-                        model=model,
-                        encoder_out=encoder_out_i,
-                        context=cur_context,
-                        max_sym_per_frame=params.max_sym_per_frame,
-                    )
-            elif params.decoding_method == "beam_search":
-                hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -582,7 +576,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    tokenizer,
+    tokenizer: spm.SentencePieceProcessor,
     biasing_dict: Dict = None,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
@@ -598,6 +592,11 @@ def decode_dataset(
       The neural model.
     sp:
       The BPE model.
+    tokenizer:
+      Tokenizer for the text encoder
+    biasing_dict:
+      A dictionary in the form `{cut_id: "w1 w2"}` that contains a list
+      of biasing words (separated with space)
     word_table:
       The word symbol table.
     decoding_graph:
@@ -627,7 +626,9 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
-        texts = batch["supervisions"]["text"]  # By default, this should be in mixed-punc format
+        texts = batch["supervisions"][
+            "text"
+        ]  # By default, this should be in mixed-punc format

         # the style of ref_text should match style_text
         texts = _apply_style_transform(texts, params.style_text_transform)
@@ -637,9 +638,13 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         if not params.use_ls_test_set:
             try:
-                book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
-            except:
-                book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
+                book_names = [
+                    cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
+                ]
+            except AttributeError:
+                book_names = [
+                    cut.id.split("/")[0] for cut in batch["supervisions"]["cut"]
+                ]
         else:
             book_names = ["" for _ in cut_ids]
@@ -657,7 +662,9 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
+            for cut_id, book_name, hyp_words, ref_text in zip(
+                cut_ids, book_names, hyps, texts
+            ):
                 ref_text = ref_text_normalization(
                     ref_text
                 )  # remove full-width symbols & some book marks
@@ -672,9 +679,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

     return results
@@ -705,7 +710,9 @@ def save_results(
     if params.compute_CER:
         # Write CER statistics
-        recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+        )
         store_transcripts(filename=recog_path, texts=results, char_level=True)
         errs_filename = (
             params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@@ -723,9 +730,7 @@ def save_results(
             logging.info("Wrote detailed CER stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -740,9 +745,7 @@ def save_results(
     if params.compute_CER:
         test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
-        errs_info = (
-            params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
         with open(errs_info, "w") as f:
             print("settings\tCER", file=f)
             for key, val in test_set_cers:
@@ -771,9 +774,6 @@ def main():
         "modified_beam_search",
     )

-    if params.long_audio_recog:
-        params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
-    else:
-        params.res_dir = params.exp_dir / params.decoding_method
+    params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
@@ -792,22 +792,19 @@ def main():
         params.suffix += f"-left-context-{params.left_context_frames}"

     if "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

     if params.use_pre_text:
-        params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
+        params.suffix += (
+            f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
+        )

     if params.use_style_prompt:
         params.suffix += f"-style-prompt-{params.style_text_transform}"

-    if params.use_context_embedding:
-        params.suffix += f"-use-context-fuser"
-
     if params.use_ls_context_list:
         params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
         if params.biasing_level == "utterance" and params.ls_distractors:
@@ -841,9 +838,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -870,9 +867,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -935,18 +932,15 @@ def main():
     test_other_cuts = libriheavy.test_other_cuts()
     ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
     ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
-    long_audio_cuts = libriheavy.long_audio_cuts()
-    npr1_dev_cuts = libriheavy.npr1_dev_cuts()
-    npr1_test_cuts = libriheavy.npr1_test_cuts()

-    test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
-    test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
+    test_clean_dl = libriheavy.valid_dataloaders(
+        test_clean_cuts, text_sampling_func=naive_triplet_text_sampling
+    )
+    test_other_dl = libriheavy.valid_dataloaders(
+        test_other_cuts, text_sampling_func=naive_triplet_text_sampling
+    )
     ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
     ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
-    long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling)
-    npr1_dev_dl = libriheavy.valid_dataloaders(npr1_dev_cuts, text_sampling_func=naive_triplet_text_sampling)
-    npr1_test_dl = libriheavy.valid_dataloaders(npr1_test_cuts, text_sampling_func=naive_triplet_text_sampling)

     if params.use_ls_test_set:
         test_sets = ["ls-test-clean", "ls-test-other"]
@@ -955,17 +949,19 @@ def main():
         test_sets = ["test-clean", "test-other"]
         test_dl = [test_clean_dl, test_other_dl]

-    if params.long_audio_recog:
-        test_sets = ["long-audio"]
-        test_dl = [long_audio_dl]
-
     for test_set, test_dl in zip(test_sets, test_dl):
-        if test_set == "ls-test-clean":
-            biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
-        elif test_set == "ls-test-other":
-            biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
-        else:
-            biasing_dict = None
+        biasing_dict = None
+        if params.use_ls_context_list:
+            if test_set == "ls-test-clean":
+                biasing_dict = get_facebook_biasing_list(
+                    test_set="test-clean",
+                    num_distractors=params.ls_distractors,
+                )
+            elif test_set == "ls-test-other":
+                biasing_dict = get_facebook_biasing_list(
+                    test_set="test-other",
+                    num_distractors=params.ls_distractors,
+                )

         results_dict = decode_dataset(
             dl=test_dl,
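
Per the docstrings added earlier in this diff, `biasing_dict` maps a cut id to one space-separated string of biasing words, and the utterance-level branch of `decode_one_batch` uses those words directly as the content prompt. A sketch with an invented mapping (`get_facebook_biasing_list` is assumed to return this shape):

    # Invented example of the expected shape: {cut_id: "w1 w2 ..."}
    biasing_dict = {
        "1089-134686-0000": "HOBGOBLIN QUINCE MARZIPAN",
        "1089-134686-0001": "SAUCEPAN FENDER",
    }
    cut_ids = ["1089-134686-0000", "1089-134686-0001"]

    # Utterance-level biasing: the biasing words become the pre_text prompt.
    pre_texts = [biasing_dict[id] for id in cut_ids]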
@@ -994,7 +990,9 @@ def main():
             for item in results_dict[k]:
                 id, ref, hyp = item
                 if params.use_ls_test_set:
-                    hyp = " ".join(hyp).replace("-", " ").split()  # handle the hyphens
+                    hyp = (
+                        " ".join(hyp).replace("-", " ").split()
+                    )  # handle the hyphens
                     hyp = upper_only_alpha(" ".join(hyp)).split()
                     hyp = [word_normalization(w.upper()) for w in hyp]
                     hyp = " ".join(hyp).split()
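
The LibriSpeech-specific post-processing above chains several normalizations; restated as a standalone helper for readability (the function name is ours; `upper_only_alpha` and `word_normalization` are the imports at the top of this file):

    from typing import List

    def normalize_ls_hyp(hyp: List[str]) -> List[str]:
        # Split hyphenated words into separate tokens: "twenty-one" -> "twenty one".
        hyp = " ".join(hyp).replace("-", " ").split()
        # Uppercase and drop non-alphabetic symbols.
        hyp = upper_only_alpha(" ".join(hyp)).split()
        # Apply LibriSpeech-style word normalization to each word.
        hyp = [word_normalization(w.upper()) for w in hyp]
        return " ".join(hyp).split()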

View File

@@ -22,24 +22,43 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
-
 # For mix precision training:

-./pruned_transducer_stateless7/train.py \
+(1) Non-streaming model, without context list
+
+./zipformer_prompt_asr/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+  --subset medium \
+  --causal False \
+  --exp-dir zipformer_prompt_asr/exp \
+  --max-duration 1000 \
+  --memory-layer 0 \
+  --memory-dim 768 \
+  --text-encoder-type BERT \
+  --use-style-prompt True \
+  --use-context-list False
+
+(2) Non-streaming model, with context list
+
+./zipformer_prompt_asr/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --subset medium \
+  --causal False \
+  --exp-dir zipformer_prompt_asr/exp \
+  --max-duration 1000 \
+  --memory-layer 0 \
+  --memory-dim 768 \
+  --text-encoder-type BERT \
+  --use-style-prompt True \
+  --use-context-list True \
+  --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
 """
@@ -61,30 +80,32 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset2 import (
-    triplet_text_sampling,
-    triplet_text_sampling_with_context_list,
+from dataset import (
     naive_triplet_text_sampling,
     random_shuffle_subset,
-    joint_triplet_text_sampling,
-    triplet_style_text_sampling,
+    triplet_text_sampling,
+    triplet_text_sampling_with_context_list,
 )
-from dataset import multi_ref_text_triplet_text_sampling
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model_with_BERT_with_style import PromptedTransducer
+from model_with_BERT import PromptedTransducer
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
+from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR
 from subsampling import Conv2dSubsampling
+from text_normalization import (
+    lower_all_char,
+    lower_only_alpha,
+    train_text_normalization,
+    upper_all_char,
+    upper_only_alpha,
+)
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from text_normalization import train_text_normalization, upper_only_alpha, lower_only_alpha, upper_all_char, lower_all_char
 from zipformer import Zipformer2

 from icefall import diagnostics
@@ -105,9 +126,7 @@ from icefall.utils import (
     str2bool,
 )

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

 style_transforms = [
     lambda x: x,  # return it self
@@ -116,9 +135,11 @@ style_transforms = [
     lower_all_char,
 ]

+
 def random_sampling(texts: List[str]) -> str:
     return random.choice(texts)

+
 def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
     # Randomly choose from the ground truth (mixed-cased trans) and the recog_text
     i = random.randint(0, 1)
@@ -130,6 +151,7 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
     }
     return out

+
 def get_first(texts: List[str], pre_texts: List[str]) -> str:
     out = {
         "text": texts[0],
@@ -139,6 +161,7 @@ def get_first(texts: List[str], pre_texts: List[str]) -> str:
     }
     return out

+
 def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
     # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
     out = {
@@ -149,6 +172,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
     }
     return out

+
 def get_adjusted_batch_count(params: AttributeDict) -> float:
     # returns the number of batches we would have used so far if we had used the reference
     # duration. This is for purposes of set_batch_count().
@@ -210,14 +234,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--memory-dropout-rate",
         type=float,
         default=0.05,
-        help="By which probability, dropout the memory when doing cross-attention."
+        help="By which probability, dropout the memory when doing cross-attention.",
     )

     parser.add_argument(
         "--memory-layer",
         type=int,
         default=0,
-        help="Start doing cross-attention from which layer. Zero-indexed"
+        help="Start doing cross-attention from which layer. Zero-indexed",
     )

     parser.add_argument(
@@ -285,8 +309,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
@@ -325,7 +348,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--text-encoder-adapter",
         type=str2bool,
         default=False,
-        help="An adapter for pre-trained BERT"
+        help="An adapter for pre-trained BERT",
     )

     parser.add_argument(
@@ -459,8 +482,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(
@@ -680,7 +702,7 @@ class TextEmbedding(nn.Module):
         layer1_channels: int = 256,
         layer2_channels: int = 256,
         bias: bool = True,
-        dropout: float = 0.1
+        dropout: float = 0.1,
     ):
         super().__init__()
         self.embed = nn.Embedding(
@@ -743,11 +765,13 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
     # Return a text encoder
     if params.text_encoder_type == "BERT":
         from transformers import BertModel
+
         # This is a BERT-base-cased
         logging.info("Loading pre-trained BERT-base-cased as text encoder")
         model = BertModel.from_pretrained("bert-base-cased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertModel
+
         # This is a DistilBERT-base-cased
         logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
@@ -756,20 +780,24 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
     return model


 def get_tokenizer(params: AttributeDict):
     if params.text_encoder_type == "BERT":
         from transformers import BertTokenizer
+
         # This is a BERT-base-cased
-        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertTokenizer
-        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
+
+        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
     else:
         raise ValueError()

     return tokenizer


 def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer2(
         output_downsampling_factor=2,
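
A short usage sketch for the BERT branch of `get_tokenizer`, mirroring the tokenizer arguments used later in `_encode_texts_as_bytes_with_tokenizer` (the sample string is illustrative):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    encoded = tokenizer(
        ["Mixed-case English transcription, with punctuation."],
        return_tensors="pt",
        padding=True,
        truncation=True,
        return_length=True,
        max_length=500,
    )
    print(encoded["input_ids"].shape)  # (batch_size, seq_len)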
@@ -812,7 +840,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
-        context_dim=4 * 768 if params.context_injection else -1,  # the output dim of text encoder
+        context_dim=4 * 768
+        if params.context_injection
+        else -1,  # the output dim of text encoder
         context_injection=params.context_injection,
     )
     return joiner
@@ -827,18 +857,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

-    if params.context_injection:
-        from context_fuser import ContextFuser, SelfAttContextFuser
-        context_fuser = SelfAttContextFuser(
-            embed_dim=768,
-            nhead=4,
-            context_dropout_rate=params.context_dropout_rate,
-        )
-        logging.info(f"Using context injection!")
-        logging.info(context_fuser)
-    else:
-        context_fuser = None
-
     model = PromptedTransducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
@@ -851,12 +869,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         text_encoder_type=params.text_encoder_type,
         text_encoder_adapter=params.text_encoder_adapter,
-        context_fuser=context_fuser,
+        context_fuser=None,
     )

-    if params.text_encoder_adapter:
-        logging.info(f"Using adapter for BERT encoder")
-        logging.info(f"{model.text_encoder_adapter}")
-
     return model
@@ -978,13 +993,14 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
 def _encode_texts_as_bytes_with_tokenizer(
     pre_texts: List[str],
     style_texts: List[str],
     tokenizer,
     device: torch.device,
     max_len: int = 500,
-    no_limit: bool=False
+    no_limit: bool = False,
 ) -> Tuple[Dict, Tensor]:
     """
     Encode texts as bytes and then integer tensors.
@@ -998,11 +1014,13 @@ def _encode_texts_as_bytes_with_tokenizer(
     else:
         allowed_lens = [1000 - len(s) for s in style_texts]
     truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)]
-    combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
+    combined_text = [
+        style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size)
+    ]

     encoded_style_texts = tokenizer(
         style_texts,
-        return_tensors='pt',
+        return_tensors="pt",
         padding=True,
         truncation=True,
         return_length=True,
@@ -1013,7 +1031,7 @@ def _encode_texts_as_bytes_with_tokenizer(
     # Use tokenizer to prepare input for text encoder
     encoded_inputs = tokenizer(
         combined_text,
-        return_tensors='pt',
+        return_tensors="pt",
         padding=True,
         truncation=True,
         return_length=True,
@@ -1022,6 +1040,7 @@ def _encode_texts_as_bytes_with_tokenizer(
     return encoded_inputs, style_lens

+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
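
To make the truncation logic above concrete: each style text takes priority, and the pre-text is cut from the left so that the pair fits a 1000-character budget before being joined with [SEP]. A minimal sketch (strings and lengths are illustrative):

    style_texts = ["Mixed-case English transcription, with punctuation."]
    pre_texts = ["x" * 2000]  # over-long history

    allowed_lens = [1000 - len(s) for s in style_texts]
    truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(len(pre_texts))]
    combined_text = [
        style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(len(pre_texts))
    ]
    assert all(len(t) <= 1000 + len(" [SEP] ") for t in combined_text)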
@@ -1048,11 +1067,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -1067,7 +1082,9 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     pre_texts = batch["supervisions"]["pre_text"]
-    style_texts = batch["supervisions"]["style_text"]  # the style texts are in gt format
+    style_texts = batch["supervisions"][
+        "style_text"
+    ]  # the style texts are in gt format
     transform_ids = batch["supervisions"]["transform_ids"]

     # This is to replace full-width symbols with half-width symbols
@@ -1075,7 +1092,9 @@ def compute_loss(
     pre_texts = [train_text_normalization(t) for t in pre_texts]
     style_texts = [train_text_normalization(t) for t in style_texts]

-    y = sp.encode(texts, out_type=int)  # sp.encode treats consecutive space as a single space
+    y = sp.encode(
+        texts, out_type=int
+    )  # sp.encode treats consecutive space as a single space
     y = k2.RaggedTensor(y).to(device)

     if params.forced_upper_pre_text:
@@ -1097,7 +1116,7 @@ def compute_loss(
             style_texts = random_shuffle_subset(
                 style_texts,
                 p=params.style_text_shuffle_prob,
-                p_mask=params.prompt_mask_prob
+                p_mask=params.prompt_mask_prob,
             )

         assert len(transform_ids) == len(style_texts)
@@ -1107,7 +1126,9 @@ def compute_loss(
             style_texts[i] = style_transforms[t](style_texts[i])

     if not params.use_style_prompt:
-        style_texts = ["" for _ in style_texts]  # use empty string for style texts if don't use style prompt
+        style_texts = [
+            "" for _ in style_texts
+        ]  # use empty string for style texts if don't use style prompt

     if random.random() < 0.05:
         logging.info(f"Pre texts: {pre_texts[0]}")
@@ -1157,9 +1178,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -1352,9 +1371,7 @@ def train_one_epoch(
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()

-            if cur_grad_scale < 8.0 or (
-                cur_grad_scale < 32.0 and batch_idx % 400 == 0
-            ):
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
@@ -1376,11 +1393,7 @@ def train_one_epoch(
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (
-                    f"grad_scale: {scaler._scale.item()}"
-                    if params.use_fp16
-                    else ""
-                )
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )

             if tb_writer is not None:
@@ -1391,9 +1404,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale",
@@ -1401,10 +1412,7 @@ def train_one_epoch(
                         params.batch_idx_train,
                     )

-        if (
-            batch_idx % params.valid_interval == 0
-            and not params.print_diagnostics
-        ):
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1455,8 +1463,12 @@ def run(rank, world_size, args):
     if not params.use_style_prompt:
         if params.pre_text_shuffle_prob == 0.0:
-            logging.info(f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}")
-            logging.info("If style prompt is not used, you should be careful when shuffling the pre_text within the same batch")
+            logging.info(
+                f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}"
+            )
+            logging.info(
+                "If style prompt is not used, you should be careful when shuffling the pre_text within the same batch"
+            )
             logging.info("Hard set this probability to 0.0!")
             params.pre_text_shuffle_prob = 0.0
@@ -1504,7 +1516,9 @@ def run(rank, world_size, args):
     if params.freeze_text_encoder:
         freeze_modules = ["text_encoder"]
-        logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
+        logging.info(
+            "Freeze the parameters of text encoder and don't include them in the optimizer"
+        )
     else:
         freeze_modules = []
@@ -1587,7 +1601,11 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = triplet_text_sampling
+    if params.use_context_list:
+        text_sampling_func = triplet_text_sampling_with_context_list
+    else:
+        text_sampling_func = triplet_text_sampling
+
     logging.info(f"Text sampling: {text_sampling_func}")
     train_dl = libriheavy.train_dataloaders(
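
Whichever branch is taken, `compute_loss` expects the chosen sampling function to attach the same fields to each supervision; a hedged sketch of that contract (key names are the ones read in this file, values invented):

    # Invented example of what a text sampling function attaches per cut:
    out = {
        "text": "He went to the store.",      # transcript used as training target
        "pre_text": "She had been there.",    # content prompt (preceding context)
        "style_text": "He went home.",        # style prompt, in ground-truth format
        "transform_ids": 0,                   # index into style_transforms above
    }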
@@ -1599,18 +1617,17 @@ def run(rank, world_size, args):
     # For fair comparison, use fixed sampling in valid dataloaders
     valid_cuts = libriheavy.dev_cuts()
     valid_dl = libriheavy.valid_dataloaders(
-        valid_cuts,
-        text_sampling_func=naive_triplet_text_sampling
+        valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )

-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints: