marcoyang1998 2023-09-19 18:38:56 +08:00
parent bea1bd295f
commit 6579800720
3 changed files with 368 additions and 357 deletions


@ -72,16 +72,12 @@ class LibriHeavyAsrDataModule:
self.args = args
if args.use_context_list:
from dataset2 import PromptASRDataset
assert args.rare_word_file is not None
with open(args.rare_word_file, "r") as f:
self.rare_word_list = (
f.read().lower().split()
) # Use lower-cased for easier style transform
else:
from dataset import PromptASRDataset
self.rare_word_list = None
@classmethod


@ -20,22 +20,55 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
./zipformer_prompt_asr/decode_bert.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--exp-dir ./zipformer_prompt_asr/exp \
--max-duration 1000 \
--decoding-method greedy_search
(2) modified beam search
./pruned_transducer_stateless7/decode.py \
./zipformer_prompt_asr/decode_bert.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--exp-dir ./zipformer_prompt_asr/exp \
--max-duration 1000 \
--decoding-method modified_beam_search \
--beam-size 4
(3) Decode LibriSpeech
./zipformer_prompt_asr/decode_bert.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer_prompt_asr/exp \
--max-duration 1000 \
--decoding-method modified_beam_search \
--use-ls-test-set True \
--beam-size 4
(4) Decode LibriSpeech + biasing list
biasing_list=100 # could also be 0
./zipformer_prompt_asr/decode_bert.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer_prompt_asr/exp \
--max-duration 1000 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-ls-test-set True \
--use-ls-context-list True \
--biasing-level utterance \
--ls-distractors $biasing_list \
--post-normalization True \
--text-encoder-type BERT \
--max-prompt-lens 1000 \
--style-text-transform mixed-punc \
--pre-text-transform mixed-punc
"""
@ -45,40 +78,34 @@ import math
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Callable
from typing import Callable, Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from asr_datamodule import LibriHeavyAsrDataModule
from beam_search import (
greedy_search,
greedy_search_with_context,
greedy_search_batch,
greedy_search_batch_with_context,
modified_beam_search,
)
from beam_search import greedy_search, greedy_search_batch, modified_beam_search
from dataset import naive_triplet_text_sampling, random_shuffle_subset
from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
from ls_text_normalization import word_normalization
from text_normalization import (
ref_text_normalization,
remove_non_alphabetic,
upper_only_alpha,
upper_all_char,
lower_all_char,
lower_only_alpha,
ref_text_normalization,
remove_non_alphabetic,
train_text_normalization,
upper_all_char,
upper_only_alpha,
)
from train_bert_encoder_with_style import (
from train_bert_encoder import (
_encode_texts_as_bytes_with_tokenizer,
add_model_arguments,
get_params,
get_tokenizer,
get_transducer_model,
_encode_texts_as_bytes_with_tokenizer,
)
from transformers import BertModel, BertTokenizer
from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats
from icefall.checkpoint import (
average_checkpoints,
@ -87,15 +114,11 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
)
from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -262,26 +285,19 @@ def get_parser():
default=True,
help="Use pre-text is available during decoding",
)
parser.add_argument(
"--use-style-prompt",
type=str2bool,
default=True,
help="Use style prompt when evaluation"
help="Use style prompt when evaluation",
)
parser.add_argument(
"--max-prompt-lens",
type=int,
default=1000,
)
parser.add_argument(
"--use-context-embedding",
type=str2bool,
default=False,
help="Use context fuser when evaluation"
)
parser.add_argument(
"--post-normalization",
@ -289,70 +305,65 @@ def get_parser():
default=True,
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
)
parser.add_argument(
"--long-audio-recog",
type=str2bool,
default=False,
)
parser.add_argument(
"--compute-CER",
type=str2bool,
default=True,
default=False,
help="Reports CER. By default, only reports WER",
)
parser.add_argument(
"--style-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
default="mixed-punc",
help="The style of style prompt, i.e style_text"
help="The style of style prompt, i.e style_text",
)
parser.add_argument(
"--pre-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
default="mixed-punc",
help="The style of content prompt, i.e pre_text"
help="The style of content prompt, i.e pre_text",
)
parser.add_argument(
"--use-ls-test-set",
type=str2bool,
default=False,
help="Use librispeech test set for evaluation."
help="Use librispeech test set for evaluation.",
)
parser.add_argument(
"--use-ls-context-list",
type=str2bool,
default=False,
help="If use a fixed context list for LibriSpeech decoding"
help="If use a fixed context list for LibriSpeech decoding",
)
parser.add_argument(
"--biasing-level",
type=str,
default="utterance",
choices=["utterance", "Book", "Chapter"],
)
parser.add_argument(
"--ls-distractors",
type=int,
default=0,
help="The number of distractors into context list for LibriSpeech decoding"
help="The number of distractors into context list for LibriSpeech decoding",
)
add_model_arguments(parser)
return parser
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
"""Apply transform to a list of text. By default, the text are in
"""Apply transform to a list of text. By default, the text are in
ground truth format, i.e mixed-punc.
Args:
@ -372,13 +383,13 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]:
return [lower_all_char(s) for s in text]
else:
raise NotImplementedError(f"Unseen transform: {transform}")
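# A minimal usage sketch of _apply_style_transform (illustration only, not from the
# original file), using a hypothetical sample sentence. The mapping is inferred from
# the transform names and the text_normalization helpers imported above, e.g.
# "upper-no-punc" should give roughly "HELLO WORLD HOW ARE YOU".
_sample = ["Hello, World. How are you?"]
for _transform in ["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"]:
    print(_transform, _apply_style_transform(_sample, _transform))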
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
tokenizer,
tokenizer: spm.SentencePieceProcessor,
batch: dict,
biasing_dict: dict = None,
word_table: Optional[k2.SymbolTable] = None,
@ -401,10 +412,15 @@ def decode_one_batch(
The neural model.
sp:
The BPE model.
tokenizer:
Tokenizer for the text encoder
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
biasing_dict:
A dictionary in the form `{cut_id: "w1 w2"}` that contains a list
of biasing words (separated by spaces)
word_table:
The word symbol table.
decoding_graph:
@ -427,48 +443,53 @@ def decode_one_batch(
cuts = batch["supervisions"]["cut"]
cut_ids = [c.supervisions[0].id for c in cuts]
batch_size = feature.size(0)
# get pre_text
if "pre_text" in batch["supervisions"] and params.use_pre_text:
pre_texts = batch["supervisions"]["pre_text"]
pre_texts = [train_text_normalization(t) for t in pre_texts]
else:
pre_texts = ["" for _ in range(batch_size)]
if params.use_ls_context_list:
if params.use_ls_context_list and params.use_ls_test_set:
if params.biasing_level == "utterance":
pre_texts = [biasing_dict[id] for id in cut_ids]
elif params.biasing_level == "Chapter":
chapter_ids = [c.split('-')[1] for c in cut_ids]
chapter_ids = [c.split("-")[1] for c in cut_ids]
pre_texts = [biasing_dict[id] for id in chapter_ids]
elif params.biasing_level == "Book":
chapter_ids = [c.split('-')[1] for c in cut_ids]
chapter_ids = [c.split("-")[1] for c in cut_ids]
pre_texts = [biasing_dict[id] for id in chapter_ids]
else:
raise ValueError(f"Unseen biasing level: {params.biasing_level}")
if params.pre_text_transform == "mixed-punc":
pre_texts = [t.lower() for t in pre_texts]
# get style_text
if params.use_style_prompt:
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
style_texts = batch["supervisions"].get(
"style_text", [fixed_sentence for _ in range(batch_size)]
)
style_texts = [train_text_normalization(t) for t in style_texts]
else:
style_texts = ["" for _ in range(batch_size)] # use empty string
style_texts = ["" for _ in range(batch_size)] # use empty string
# Get the text embedding input
# Get the text embedding
if params.use_pre_text or params.use_style_prompt:
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
if not params.use_ls_context_list:
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts]
if params.use_style_prompt:
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
style_texts = _apply_style_transform(
style_texts, params.style_text_transform
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Use tokenizer to prepare input for text encoder
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
pre_texts=pre_texts,
@ -477,12 +498,14 @@ def decode_one_batch(
device=device,
no_limit=True,
)
logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")
logging.info(
f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}"
)
memory, memory_key_padding_mask = model.encode_text(
encoded_inputs=encoded_inputs,
style_lens=style_lens,
) # (T,B,C)
) # (T,B,C)
else:
memory = None
memory_key_padding_mask = None
@ -506,26 +529,12 @@ def decode_one_batch(
hyps = []
if (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
if memory is None or not params.use_context_embedding:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
else:
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
context = model.joiner.context_proj(context) # (N,C)
hyp_tokens = greedy_search_batch_with_context(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context=context,
)
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
@ -545,25 +554,10 @@ def decode_one_batch(
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
if memory is None or not params.use_context_embedding:
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
else:
cur_context = context[i:i+1, :]
hyp = greedy_search_with_context(
model=model,
encoder_out=encoder_out_i,
context=cur_context,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
max_sym_per_frame=params.max_sym_per_frame,
)
else:
raise ValueError(
@ -582,7 +576,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
tokenizer,
tokenizer: spm.SentencePieceProcessor,
biasing_dict: Dict = None,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
@ -598,6 +592,11 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
tokenizer:
Tokenizer for the text encoder
biasing_dict:
A dictionary in the form `{cut_id: "w1 w2"}` that contains a list
of biasing words (separated by spaces)
word_table:
The word symbol table.
decoding_graph:
@ -627,19 +626,25 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
texts = batch["supervisions"][
"text"
] # By default, this should be in mixed-punc format
# the style of ref_text should match style_text
texts = _apply_style_transform(texts, params.style_text_transform)
texts = _apply_style_transform(texts, params.style_text_transform)
if params.use_style_prompt:
texts = _apply_style_transform(texts, params.style_text_transform)
texts = _apply_style_transform(texts, params.style_text_transform)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if not params.use_ls_test_set:
try:
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
except:
book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
book_names = [
cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
]
except AttributeError:
book_names = [
cut.id.split("/")[0] for cut in batch["supervisions"]["cut"]
]
else:
book_names = ["" for _ in cut_ids]
@ -657,7 +662,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
for cut_id, book_name, hyp_words, ref_text in zip(
cut_ids, book_names, hyps, texts
):
ref_text = ref_text_normalization(
ref_text
) # remove full-width symbols & some book marks
@ -672,9 +679,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@ -705,7 +710,9 @@ def save_results(
if params.compute_CER:
# Write CER statistics
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
recog_path = (
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results, char_level=True)
errs_filename = (
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@ -723,9 +730,7 @@ def save_results(
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@ -740,9 +745,7 @@ def save_results(
if params.compute_CER:
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_cers:
@ -770,11 +773,8 @@ def main():
"greedy_search",
"modified_beam_search",
)
if params.long_audio_recog:
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
else:
params.res_dir = params.exp_dir / params.decoding_method
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
@ -792,22 +792,19 @@ def main():
params.suffix += f"-left-context-{params.left_context_frames}"
if "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_pre_text:
params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
params.suffix += (
f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
)
if params.use_style_prompt:
params.suffix += f"-style-prompt-{params.style_text_transform}"
if params.use_context_embedding:
params.suffix += f"-use-context-fuser"
if params.use_ls_context_list:
params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
if params.biasing_level == "utterance" and params.ls_distractors:
@ -841,9 +838,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -870,9 +867,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -935,18 +932,15 @@ def main():
test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
long_audio_cuts = libriheavy.long_audio_cuts()
npr1_dev_cuts = libriheavy.npr1_dev_cuts()
npr1_test_cuts = libriheavy.npr1_test_cuts()
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
test_clean_dl = libriheavy.valid_dataloaders(
test_clean_cuts, text_sampling_func=naive_triplet_text_sampling
)
test_other_dl = libriheavy.valid_dataloaders(
test_other_cuts, text_sampling_func=naive_triplet_text_sampling
)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling)
npr1_dev_dl = libriheavy.valid_dataloaders(npr1_dev_cuts, text_sampling_func=naive_triplet_text_sampling)
npr1_test_dl = libriheavy.valid_dataloaders(npr1_test_cuts, text_sampling_func=naive_triplet_text_sampling)
if params.use_ls_test_set:
test_sets = ["ls-test-clean", "ls-test-other"]
@ -954,19 +948,21 @@ def main():
else:
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
if params.long_audio_recog:
test_sets = ["long-audio"]
test_dl = [long_audio_dl]
for test_set, test_dl in zip(test_sets, test_dl):
if test_set == "ls-test-clean":
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
elif test_set == "ls-test-other":
biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
else:
biasing_dict = None
biasing_dict = None
if params.use_ls_context_list:
if test_set == "ls-test-clean":
biasing_dict = get_facebook_biasing_list(
test_set="test-clean",
num_distractors=params.ls_distractors,
)
elif test_set == "ls-test-other":
biasing_dict = get_facebook_biasing_list(
test_set="test-other",
num_distractors=params.ls_distractors,
)
results_dict = decode_dataset(
dl=test_dl,
params=params,
@ -983,35 +979,37 @@ def main():
test_set_name=test_set,
results_dict=results_dict,
)
if params.post_normalization:
if "-post-normalization" not in params.suffix:
params.suffix += "-post-normalization"
new_res = {}
for k in results_dict:
new_ans = []
for item in results_dict[k]:
id, ref, hyp = item
if params.use_ls_test_set:
hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
hyp = (
" ".join(hyp).replace("-", " ").split()
) # handle the hyphens
hyp = upper_only_alpha(" ".join(hyp)).split()
hyp = [word_normalization(w.upper()) for w in hyp]
hyp = " ".join(hyp).split()
hyp = [w for w in hyp if w != ""]
ref = upper_only_alpha(" ".join(ref)).split()
ref = upper_only_alpha(" ".join(ref)).split()
else:
hyp = upper_only_alpha(" ".join(hyp)).split()
ref = upper_only_alpha(" ".join(ref)).split()
new_ans.append((id,ref,hyp))
ref = upper_only_alpha(" ".join(ref)).split()
new_ans.append((id, ref, hyp))
new_res[k] = new_ans
save_results(
params=params,
test_set_name=test_set,
results_dict=new_res,
)
if params.suffix.endswith("-post-normalization"):
params.suffix = params.suffix.replace("-post-normalization", "")


@ -22,24 +22,43 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 300
# For mix precision training:
./pruned_transducer_stateless7/train.py \
(1) Non-streaming model, without context list
./zipformer_prompt_asr/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 550
--subset medium \
--causal False \
--exp-dir zipformer_prompt_asr/exp \
--max-duration 1000 \
--memory-layer 0 \
--memory-dim 768 \
--text-encoder-type BERT \
--use-style-prompt True \
--use-context-list False
(2) Non-streaming model, with context list
./zipformer_prompt_asr/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--subset medium \
--causal False \
--exp-dir zipformer_prompt_asr/exp \
--max-duration 1000 \
--memory-layer 0 \
--memory-dim 768 \
--text-encoder-type BERT \
--use-style-prompt True \
--use-context-list True \
--rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
"""
@ -61,30 +80,32 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule
from dataset2 import (
triplet_text_sampling,
triplet_text_sampling_with_context_list,
from dataset import (
naive_triplet_text_sampling,
random_shuffle_subset,
joint_triplet_text_sampling,
triplet_style_text_sampling,
triplet_text_sampling,
triplet_text_sampling_with_context_list,
)
from dataset import multi_ref_text_triplet_text_sampling
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model_with_BERT_with_style import PromptedTransducer
from model_with_BERT import PromptedTransducer
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR
from subsampling import Conv2dSubsampling
from text_normalization import (
lower_all_char,
lower_only_alpha,
train_text_normalization,
upper_all_char,
upper_only_alpha,
)
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from text_normalization import train_text_normalization, upper_only_alpha, lower_only_alpha, upper_all_char, lower_all_char
from zipformer import Zipformer2
from icefall import diagnostics
@ -105,20 +126,20 @@ from icefall.utils import (
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
style_transforms = [
lambda x: x, # return it self
lambda x: x, # return itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
lower_all_char,
]
def random_sampling(texts: List[str]) -> str:
return random.choice(texts)
def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
# Randomly choose from the ground truth (mixed-cased trans) and the recog_text
i = random.randint(0, 1)
@ -130,6 +151,7 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
}
return out
def get_first(texts: List[str], pre_texts: List[str]) -> str:
out = {
"text": texts[0],
@ -139,6 +161,7 @@ def get_first(texts: List[str], pre_texts: List[str]) -> str:
}
return out
def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
out = {
@ -149,6 +172,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
}
return out
def get_adjusted_batch_count(params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference
# duration. This is for purposes of set_batch_count().
@ -205,19 +229,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="192,256,384,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--memory-dropout-rate",
type=float,
default=0.05,
help="By which probability, dropout the memory when doing cross-attention."
help="By which probability, dropout the memory when doing cross-attention.",
)
parser.add_argument(
"--memory-layer",
type=int,
default=0,
help="Start doing cross-attention from which layer. Zero-indexed"
help="Start doing cross-attention from which layer. Zero-indexed",
)
parser.add_argument(
@ -226,7 +250,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default="32",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--value-head-dim",
type=str,
@ -280,13 +304,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
to this dimension before adding.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -312,29 +335,29 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
)
parser.add_argument(
"--text-encoder-type",
type=str,
default="BERT",
choices=["BERT","DistilBERT"],
choices=["BERT", "DistilBERT"],
help="Type of the text encoder",
)
parser.add_argument(
"--text-encoder-adapter",
type=str2bool,
default=False,
help="An adapter for pre-trained BERT"
help="An adapter for pre-trained BERT",
)
parser.add_argument(
"--context-injection",
type=str2bool,
default=False,
help="Inject context embedding into the joiner",
)
parser.add_argument(
"--context-dropout-rate",
type=float,
@ -459,8 +482,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -537,14 +559,14 @@ def get_parser():
default=False,
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-style-prompt",
type=str2bool,
default=True,
help="Whether to use style prompt.",
)
# arguments for using prompt
parser.add_argument(
"--pre-text-shuffle-prob",
@ -552,14 +574,14 @@ def get_parser():
default=0.05,
help="The proportion of pre_text to be shuffled with in a batch",
)
parser.add_argument(
"--style-text-shuffle-prob",
type=float,
default=0.2,
help="The proportion of style_text to be shuffled with in a batch",
)
parser.add_argument(
"--prompt-mask-prob",
type=float,
@ -571,14 +593,14 @@ def get_parser():
type=str2bool,
default=True,
)
parser.add_argument(
"--forced-upper-pre-text",
type=str2bool,
default=False,
help="Forced format of pre-text",
)
add_model_arguments(parser)
return parser
@ -674,25 +696,25 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
class TextEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int=256,
embedding_dim: int=256,
kernel_size: int=3,
num_embeddings: int = 256,
embedding_dim: int = 256,
kernel_size: int = 3,
layer1_channels: int = 256,
layer2_channels: int = 256,
bias: bool=True,
dropout: float = 0.1
bias: bool = True,
dropout: float = 0.1,
):
super().__init__()
self.embed = nn.Embedding(
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
embedding_dim=embedding_dim, #
embedding_dim=embedding_dim, #
)
assert embedding_dim == layer1_channels # for depth wise convolution
assert embedding_dim == layer1_channels # for depth wise convolution
self.conv = nn.Sequential(
nn.Conv1d(
embedding_dim,
layer1_channels, # depthwise convolution
layer1_channels, # depthwise convolution
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
@ -705,7 +727,7 @@ class TextEmbedding(nn.Module):
nn.Conv1d(
layer1_channels,
layer2_channels,
kernel_size=1, # pointwise convolution
kernel_size=1, # pointwise convolution
stride=1,
padding=0,
bias=True,
@ -713,10 +735,10 @@ class TextEmbedding(nn.Module):
Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
nn.ReLU(),
)
self.out_norm = BiasNorm(layer2_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(self, text: torch.Tensor) -> torch.Tensor:
"""Forward function of the text embedding
@ -725,51 +747,57 @@ class TextEmbedding(nn.Module):
Returns:
The embeddings of text (T,N,C)
"""
text = self.embed(text) # (T,N,C)
#src = text
text = text.permute(1,2,0) # (T,N,C) -> (N,C,T)
text = self.embed(text) # (T,N,C)
# src = text
text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T)
text = self.conv(text)
text = text.permute(2,0,1) # (N,C,T) -> (T,N,C)
#src = src + text
text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C)
# src = src + text
text = self.out_norm(text)
text = self.dropout(text)
return text
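# A minimal usage sketch of TextEmbedding (illustration only, not from the original
# file), assuming the input is a (T, N) LongTensor of UTF-8 byte values, which is
# why num_embeddings defaults to 256. torch is already imported at the top of this file.
_text_embedding = TextEmbedding()
_byte_ids = torch.tensor(
    list("hello world".encode("utf-8")), dtype=torch.long
).unsqueeze(1)  # (T, N) with N=1
_embedded = _text_embedding(_byte_ids)  # (T, N, layer2_channels), here (11, 1, 256)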
def get_text_encoder(params: AttributeDict) -> nn.Module:
# Return a text encoder
if params.text_encoder_type == "BERT":
from transformers import BertModel
# This is a BERT-base-cased
logging.info("Loading pre-trained BERT-base-cased as text encoder")
model = BertModel.from_pretrained("bert-base-cased")
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertModel
# This is a DistilBERT-base-cased
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
model = DistilBertModel.from_pretrained("distilbert-base-cased")
else:
raise ValueError()
return model
def get_tokenizer(params: AttributeDict):
if params.text_encoder_type == "BERT":
from transformers import BertTokenizer
# This is a BERT-base-cased
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
else:
raise ValueError()
return tokenizer
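# A small sanity-check sketch (illustration only, not from the original file) of the
# tokenizer and text encoder returned above; the params object here is constructed
# just for this example. It shows why memory_dim is fixed to 768 further below:
# BERT-base hidden states are 768-dimensional.
_params = AttributeDict({"text_encoder_type": "BERT"})
_tokenizer = get_tokenizer(_params)
_text_encoder = get_text_encoder(_params)
_inputs = _tokenizer(
    ["Mixed-case English transcription, with punctuation."], return_tensors="pt"
)
_hidden = _text_encoder(**_inputs).last_hidden_state  # (batch, seq_len, 768)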
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Zipformer2(
output_downsampling_factor=2,
@ -789,7 +817,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
memory_dim=768, # This is fixed as the BERT base model is 768-D
memory_dim=768, # This is fixed as the BERT base model is 768-D
memory_layer=params.memory_layer,
memory_dropout_rate=params.memory_dropout_rate,
)
@ -812,7 +840,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
context_dim=4 * 768 if params.context_injection else -1, # the output dim of text encoder
context_dim=4 * 768
if params.context_injection
else -1, # the output dim of text encoder
context_injection=params.context_injection,
)
return joiner
@ -821,23 +851,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
num_param = sum([p.numel() for p in text_encoder.parameters()])
logging.info(f"Num params in text encoder: {num_param}")
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
if params.context_injection:
from context_fuser import ContextFuser, SelfAttContextFuser
context_fuser = SelfAttContextFuser(
embed_dim=768,
nhead=4,
context_dropout_rate=params.context_dropout_rate,
)
logging.info(f"Using context injection!")
logging.info(context_fuser)
else:
context_fuser = None
model = PromptedTransducer(
encoder_embed=encoder_embed,
@ -851,12 +869,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
text_encoder_type=params.text_encoder_type,
text_encoder_adapter=params.text_encoder_adapter,
context_fuser=context_fuser,
context_fuser=None,
)
if params.text_encoder_adapter:
logging.info(f"Using adapter for BERT encoder")
logging.info(f"{model.text_encoder_adapter}")
return model
@ -978,13 +993,14 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def _encode_texts_as_bytes_with_tokenizer(
pre_texts: List[str],
pre_texts: List[str],
style_texts: List[str],
tokenizer,
device: torch.device,
max_len: int=500,
no_limit: bool=False
max_len: int = 500,
no_limit: bool = False,
) -> Tuple[Dict, Tensor]:
"""
Encode the style and pre texts with the text-encoder tokenizer into integer tensors.
@ -992,36 +1008,39 @@ def _encode_texts_as_bytes_with_tokenizer(
"""
batch_size = len(pre_texts)
max_len = min(max_len, 500)
if no_limit:
allowed_lens = [5000 - len(s) for s in style_texts]
else:
allowed_lens = [1000 - len(s) for s in style_texts]
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)]
combined_text = [
style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size)
]
encoded_style_texts = tokenizer(
style_texts,
return_tensors='pt',
return_tensors="pt",
padding=True,
truncation=True,
return_length=True,
max_length=max_len,
)
style_lens = encoded_style_texts["length"].to(device)
# Use tokenizer to prepare input for text encoder
encoded_inputs = tokenizer(
combined_text,
return_tensors='pt',
return_tensors="pt",
padding=True,
truncation=True,
return_length=True,
max_length=max_len,
).to(device)
return encoded_inputs, style_lens
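# A minimal usage sketch (illustration only, not from the original file), mirroring
# how decode_bert.py and compute_loss() call this helper; the example strings are
# hypothetical.
_tokenizer = get_tokenizer(AttributeDict({"text_encoder_type": "BERT"}))
_encoded_inputs, _style_lens = _encode_texts_as_bytes_with_tokenizer(
    pre_texts=["some previous transcript containing RARE WORDS"],
    style_texts=["Mixed-case English transcription, with punctuation."],
    tokenizer=_tokenizer,
    device=torch.device("cpu"),
)
# _encoded_inputs["input_ids"]: (batch, seq_len) token ids of "<style_text> [SEP] <pre_text>"
# _style_lens: (batch,) lengths of the tokenized style prompts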
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -1048,11 +1067,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -1067,20 +1082,24 @@ def compute_loss(
texts = batch["supervisions"]["text"]
pre_texts = batch["supervisions"]["pre_text"]
style_texts = batch["supervisions"]["style_text"] # the style texts are in gt format
style_texts = batch["supervisions"][
"style_text"
] # the style texts are in gt format
transform_ids = batch["supervisions"]["transform_ids"]
# This is to replace full-width symbols with half-width symbols
texts = [train_text_normalization(t) for t in texts]
pre_texts = [train_text_normalization(t) for t in pre_texts]
style_texts = [train_text_normalization(t) for t in style_texts]
y = sp.encode(texts, out_type=int) # sp.encode treats consecutive space as a single space
y = sp.encode(
texts, out_type=int
) # sp.encode treats consecutive space as a single space
y = k2.RaggedTensor(y).to(device)
if params.forced_upper_pre_text:
pre_texts = [upper_only_alpha(p) for p in pre_texts]
# only shuffle the pre_text and style_texts during training, and when the style prompt is used
if is_training:
# randomly shuffle&mask the pre_text
@ -1089,38 +1108,40 @@ def compute_loss(
p=params.pre_text_shuffle_prob,
p_mask=params.prompt_mask_prob,
)
if params.use_style_prompt:
if random.random() < 0.5:
if random.random() < 0.5:
# randomly shuffle the style_text
# now the style_texts are all in gt format
style_texts = random_shuffle_subset(
style_texts,
p=params.style_text_shuffle_prob,
p_mask=params.prompt_mask_prob
)
p_mask=params.prompt_mask_prob,
)
assert len(transform_ids) == len(style_texts)
for i in range(len(style_texts)):
t = transform_ids[i] # get the transform id
t = transform_ids[i] # get the transform id
style_texts[i] = style_transforms[t](style_texts[i])
if not params.use_style_prompt:
style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt
style_texts = [
"" for _ in style_texts
] # use empty string for style texts if style prompt is not used
if random.random() < 0.05:
logging.info(f"Pre texts: {pre_texts[0]}")
logging.info(f"Ref texts: {texts[0]}")
logging.info(f"Style texts: {style_texts[0]}")
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
pre_texts=pre_texts,
style_texts=style_texts,
tokenizer=tokenizer,
device=device,
)
if random.random() < 0.02:
logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
@ -1157,9 +1178,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -1352,9 +1371,7 @@ def train_one_epoch(
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and batch_idx % 400 == 0
):
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@ -1376,11 +1393,7 @@ def train_one_epoch(
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
@ -1391,9 +1404,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale",
@ -1401,10 +1412,7 @@ def train_one_epoch(
params.batch_idx_train,
)
if (
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -1452,11 +1460,15 @@ def run(rank, world_size, args):
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if not params.use_style_prompt:
if params.pre_text_shuffle_prob == 0.0:
logging.info(f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}")
logging.info("If style prompt is not used, you should be careful when shuffling the pre_text within the same batch")
if params.pre_text_shuffle_prob == 0.0:
logging.info(
f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}"
)
logging.info(
"If style prompt is not used, you should be careful when shuffling the pre_text within the same batch"
)
logging.info("Hard set this probability to 0.0!")
params.pre_text_shuffle_prob = 0.0
@ -1504,10 +1516,12 @@ def run(rank, world_size, args):
if params.freeze_text_encoder:
freeze_modules = ["text_encoder"]
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
logging.info(
"Freeze the parameters of text encoder and don't include them in the optimizer"
)
else:
freeze_modules = []
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
@ -1533,7 +1547,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
args.max_duration = 100
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1543,7 +1557,7 @@ def run(rank, world_size, args):
libriheavy = LibriHeavyAsrDataModule(args)
train_cuts = libriheavy.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@ -1586,10 +1600,14 @@ def run(rank, world_size, args):
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
text_sampling_func = triplet_text_sampling
if params.use_context_list:
text_sampling_func = triplet_text_sampling_with_context_list
else:
text_sampling_func = triplet_text_sampling
logging.info(f"Text sampling: {text_sampling_func}")
train_dl = libriheavy.train_dataloaders(
train_cuts,
sampler_state_dict=sampler_state_dict,
@ -1599,18 +1617,17 @@ def run(rank, world_size, args):
# For fair comparison, use fixed sampling in valid dataloaders
valid_cuts = libriheavy.dev_cuts()
valid_dl = libriheavy.valid_dataloaders(
valid_cuts,
text_sampling_func=naive_triplet_text_sampling
valid_cuts, text_sampling_func=naive_triplet_text_sampling
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: