mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
update
This commit is contained in:
parent
bea1bd295f
commit
6579800720
@ -72,16 +72,12 @@ class LibriHeavyAsrDataModule:
|
||||
self.args = args
|
||||
|
||||
if args.use_context_list:
|
||||
from dataset2 import PromptASRDataset
|
||||
|
||||
assert args.rare_word_file is not None
|
||||
with open(args.rare_word_file, "r") as f:
|
||||
self.rare_word_list = (
|
||||
f.read().lower().split()
|
||||
) # Use lower-cased for easier style transform
|
||||
else:
|
||||
from dataset import PromptASRDataset
|
||||
|
||||
self.rare_word_list = None
|
||||
|
||||
@classmethod
|
||||
|
@ -20,22 +20,55 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./zipformer_prompt_asr/decode_bert.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--max-duration 1000 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./zipformer_prompt_asr/decode_bert.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--max-duration 1000 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) Decode LibriSpeech
|
||||
|
||||
./zipformer_prompt_asr/decode_bert.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--max-duration 1000 \
|
||||
--decoding-method modified_beam_search \
|
||||
--use-ls-test-set True \
|
||||
--beam-size 4
|
||||
|
||||
(4) Decode LibriSpeech + biasing list
|
||||
|
||||
biasing_list=100 # could also be 0
|
||||
|
||||
./zipformer_prompt_asr/decode_bert.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_prompt_asr/exp \
|
||||
--max-duration 1000 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4 \\
|
||||
--use-ls-test-set True \
|
||||
--use-ls-context-list True \
|
||||
--biasing-level utterance \
|
||||
--ls-distractors $biasing_list \
|
||||
--post-normalization True \
|
||||
--text-encoder-type BERT \
|
||||
--max-prompt-lens 1000 \
|
||||
--style-text-transform mixed-punc \
|
||||
--pre-text-transform mixed-punc
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@ -45,40 +78,34 @@ import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Callable
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BertTokenizer, BertModel
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_with_context,
|
||||
greedy_search_batch,
|
||||
greedy_search_batch_with_context,
|
||||
modified_beam_search,
|
||||
)
|
||||
from beam_search import greedy_search, greedy_search_batch, modified_beam_search
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
|
||||
from ls_text_normalization import word_normalization
|
||||
from text_normalization import (
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from train_bert_encoder_with_style import (
|
||||
from train_bert_encoder import (
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
)
|
||||
from transformers import BertModel, BertTokenizer
|
||||
from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -87,15 +114,11 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -262,26 +285,19 @@ def get_parser():
|
||||
default=True,
|
||||
help="Use pre-text is available during decoding",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
help="Use style prompt when evaluation",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--max-prompt-lens",
|
||||
type=int,
|
||||
default=1000,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-context-embedding",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use context fuser when evaluation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
@ -289,70 +305,65 @@ def get_parser():
|
||||
default=True,
|
||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--long-audio-recog",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-CER",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
default=False,
|
||||
help="Reports CER. By default, only reports WER",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
help="The style of style prompt, i.e style_text",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
help="The style of content prompt, i.e pre_text",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-test-set",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use librispeech test set for evaluation."
|
||||
help="Use librispeech test set for evaluation.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-ls-context-list",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If use a fixed context list for LibriSpeech decoding"
|
||||
help="If use a fixed context list for LibriSpeech decoding",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--biasing-level",
|
||||
type=str,
|
||||
default="utterance",
|
||||
choices=["utterance", "Book", "Chapter"],
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--ls-distractors",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The number of distractors into context list for LibriSpeech decoding"
|
||||
help="The number of distractors into context list for LibriSpeech decoding",
|
||||
)
|
||||
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
@ -372,13 +383,13 @@ def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
tokenizer,
|
||||
tokenizer: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
biasing_dict: dict = None,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
@ -401,10 +412,15 @@ def decode_one_batch(
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
tokenizer:
|
||||
Tokenizer for the text encoder
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
biasing_dict:
|
||||
A dictionary in the form `{cut_id: :w1 w2"}` that contains a list
|
||||
of biasing words (separated with space)
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
@ -427,48 +443,53 @@ def decode_one_batch(
|
||||
cuts = batch["supervisions"]["cut"]
|
||||
cut_ids = [c.supervisions[0].id for c in cuts]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
# get pre_text
|
||||
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"]["pre_text"]
|
||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
|
||||
if params.use_ls_context_list:
|
||||
|
||||
if params.use_ls_context_list and params.use_ls_test_set:
|
||||
if params.biasing_level == "utterance":
|
||||
pre_texts = [biasing_dict[id] for id in cut_ids]
|
||||
elif params.biasing_level == "Chapter":
|
||||
chapter_ids = [c.split('-')[1] for c in cut_ids]
|
||||
chapter_ids = [c.split("-")[1] for c in cut_ids]
|
||||
pre_texts = [biasing_dict[id] for id in chapter_ids]
|
||||
elif params.biasing_level == "Book":
|
||||
chapter_ids = [c.split('-')[1] for c in cut_ids]
|
||||
chapter_ids = [c.split("-")[1] for c in cut_ids]
|
||||
pre_texts = [biasing_dict[id] for id in chapter_ids]
|
||||
else:
|
||||
raise ValueError(f"Unseen biasing level: {params.biasing_level}")
|
||||
if params.pre_text_transform == "mixed-punc":
|
||||
pre_texts = [t.lower() for t in pre_texts]
|
||||
|
||||
|
||||
# get style_text
|
||||
if params.use_style_prompt:
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
||||
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
|
||||
style_texts = batch["supervisions"].get(
|
||||
"style_text", [fixed_sentence for _ in range(batch_size)]
|
||||
)
|
||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
|
||||
# Get the text embedding input
|
||||
# Get the text embedding
|
||||
if params.use_pre_text or params.use_style_prompt:
|
||||
|
||||
# apply style transform to the pre_text and style_text
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
if not params.use_ls_context_list:
|
||||
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
|
||||
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||
pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts]
|
||||
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
style_texts = _apply_style_transform(
|
||||
style_texts, params.style_text_transform
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
|
||||
# Use tokenizer to prepare input for text encoder
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
@ -477,12 +498,14 @@ def decode_one_batch(
|
||||
device=device,
|
||||
no_limit=True,
|
||||
)
|
||||
logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")
|
||||
|
||||
logging.info(
|
||||
f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}"
|
||||
)
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
@ -506,26 +529,12 @@ def decode_one_batch(
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
if memory is None or not params.use_context_embedding:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
else:
|
||||
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
|
||||
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
|
||||
context = model.joiner.context_proj(context) # (N,C)
|
||||
hyp_tokens = greedy_search_batch_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context=context,
|
||||
)
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
@ -545,25 +554,10 @@ def decode_one_batch(
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
if memory is None or not params.use_context_embedding:
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
cur_context = context[i:i+1, :]
|
||||
hyp = greedy_search_with_context(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
context=cur_context,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -582,7 +576,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
tokenizer,
|
||||
tokenizer: spm.SentencePieceProcessor,
|
||||
biasing_dict: Dict = None,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
@ -598,6 +592,11 @@ def decode_dataset(
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
tokenizer:
|
||||
Tokenizer for the text encoder
|
||||
biasing_dict:
|
||||
A dictionary in the form `{cut_id: :w1 w2"}` that contains a list
|
||||
of biasing words (separated with space)
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
@ -627,19 +626,25 @@ def decode_dataset(
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
||||
|
||||
texts = batch["supervisions"][
|
||||
"text"
|
||||
] # By default, this should be in mixed-punc format
|
||||
|
||||
# the style of ref_text should match style_text
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
if params.use_style_prompt:
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
if not params.use_ls_test_set:
|
||||
try:
|
||||
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
|
||||
except:
|
||||
book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
|
||||
book_names = [
|
||||
cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
|
||||
]
|
||||
except AttributeError:
|
||||
book_names = [
|
||||
cut.id.split("/")[0] for cut in batch["supervisions"]["cut"]
|
||||
]
|
||||
else:
|
||||
book_names = ["" for _ in cut_ids]
|
||||
|
||||
@ -657,7 +662,9 @@ def decode_dataset(
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
|
||||
for cut_id, book_name, hyp_words, ref_text in zip(
|
||||
cut_ids, book_names, hyps, texts
|
||||
):
|
||||
ref_text = ref_text_normalization(
|
||||
ref_text
|
||||
) # remove full-width symbols & some book marks
|
||||
@ -672,9 +679,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -705,7 +710,9 @@ def save_results(
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
@ -723,9 +730,7 @@ def save_results(
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
@ -740,9 +745,7 @@ def save_results(
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
@ -770,11 +773,8 @@ def main():
|
||||
"greedy_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
|
||||
if params.long_audio_recog:
|
||||
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
|
||||
else:
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
@ -792,22 +792,19 @@ def main():
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_pre_text:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
|
||||
|
||||
params.suffix += (
|
||||
f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
|
||||
)
|
||||
|
||||
if params.use_style_prompt:
|
||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||
|
||||
if params.use_context_embedding:
|
||||
params.suffix += f"-use-context-fuser"
|
||||
|
||||
|
||||
if params.use_ls_context_list:
|
||||
params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
|
||||
if params.biasing_level == "utterance" and params.ls_distractors:
|
||||
@ -841,9 +838,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -870,9 +867,9 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -935,18 +932,15 @@ def main():
|
||||
test_other_cuts = libriheavy.test_other_cuts()
|
||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||
long_audio_cuts = libriheavy.long_audio_cuts()
|
||||
|
||||
npr1_dev_cuts = libriheavy.npr1_dev_cuts()
|
||||
npr1_test_cuts = libriheavy.npr1_test_cuts()
|
||||
|
||||
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
test_clean_dl = libriheavy.valid_dataloaders(
|
||||
test_clean_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
test_other_dl = libriheavy.valid_dataloaders(
|
||||
test_other_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
npr1_dev_dl = libriheavy.valid_dataloaders(npr1_dev_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
npr1_test_dl = libriheavy.valid_dataloaders(npr1_test_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
|
||||
if params.use_ls_test_set:
|
||||
test_sets = ["ls-test-clean", "ls-test-other"]
|
||||
@ -954,19 +948,21 @@ def main():
|
||||
else:
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
if params.long_audio_recog:
|
||||
test_sets = ["long-audio"]
|
||||
test_dl = [long_audio_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
if test_set == "ls-test-clean":
|
||||
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
|
||||
elif test_set == "ls-test-other":
|
||||
biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
|
||||
else:
|
||||
biasing_dict = None
|
||||
|
||||
biasing_dict = None
|
||||
if params.use_ls_context_list:
|
||||
if test_set == "ls-test-clean":
|
||||
biasing_dict = get_facebook_biasing_list(
|
||||
test_set="test-clean",
|
||||
num_distractors=params.ls_distractors,
|
||||
)
|
||||
elif test_set == "ls-test-other":
|
||||
biasing_dict = get_facebook_biasing_list(
|
||||
test_set="test-other",
|
||||
num_distractors=params.ls_distractors,
|
||||
)
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
@ -983,35 +979,37 @@ def main():
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
|
||||
if params.post_normalization:
|
||||
if "-post-normalization" not in params.suffix:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
|
||||
new_res = {}
|
||||
for k in results_dict:
|
||||
new_ans = []
|
||||
for item in results_dict[k]:
|
||||
id, ref, hyp = item
|
||||
if params.use_ls_test_set:
|
||||
hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
|
||||
hyp = (
|
||||
" ".join(hyp).replace("-", " ").split()
|
||||
) # handle the hypens
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
hyp = [word_normalization(w.upper()) for w in hyp]
|
||||
hyp = " ".join(hyp).split()
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
else:
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_ans.append((id,ref,hyp))
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_ans.append((id, ref, hyp))
|
||||
new_res[k] = new_ans
|
||||
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=new_res,
|
||||
)
|
||||
|
||||
|
||||
if params.suffix.endswith("-post-normalization"):
|
||||
params.suffix = params.suffix.replace("-post-normalization", "")
|
||||
|
||||
|
@ -22,24 +22,43 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
(1) Non-streaming model, without context list
|
||||
|
||||
./zipformer_prompt_asr/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
--subset medium \
|
||||
--causal False \
|
||||
--exp-dir zipformer_prompt_asr/exp \
|
||||
--max-duration 1000 \
|
||||
--memory-layer 0 \
|
||||
--memory-dim 768 \
|
||||
--text-encoder-type BERT \
|
||||
--use-style-prompt True \
|
||||
--use-context-list False
|
||||
|
||||
(2) Non-streaming model, with context list
|
||||
|
||||
./zipformer_prompt_asr/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--subset medium \
|
||||
--causal False \
|
||||
--exp-dir zipformer_prompt_asr/exp \
|
||||
--max-duration 1000 \
|
||||
--memory-layer 0 \
|
||||
--memory-dim 768 \
|
||||
--text-encoder-type BERT \
|
||||
--use-style-prompt True \
|
||||
--use-context-list True \
|
||||
--rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@ -61,30 +80,32 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from dataset2 import (
|
||||
triplet_text_sampling,
|
||||
triplet_text_sampling_with_context_list,
|
||||
from dataset import (
|
||||
naive_triplet_text_sampling,
|
||||
random_shuffle_subset,
|
||||
joint_triplet_text_sampling,
|
||||
triplet_style_text_sampling,
|
||||
triplet_text_sampling,
|
||||
triplet_text_sampling_with_context_list,
|
||||
)
|
||||
from dataset import multi_ref_text_triplet_text_sampling
|
||||
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model_with_BERT_with_style import PromptedTransducer
|
||||
from model_with_BERT import PromptedTransducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
|
||||
from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR
|
||||
from subsampling import Conv2dSubsampling
|
||||
from text_normalization import (
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from text_normalization import train_text_normalization, upper_only_alpha, lower_only_alpha, upper_all_char, lower_all_char
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall import diagnostics
|
||||
@ -105,20 +126,20 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
style_transforms = [
|
||||
lambda x: x, # return it self
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
|
||||
def random_sampling(texts: List[str]) -> str:
|
||||
return random.choice(texts)
|
||||
|
||||
|
||||
def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
|
||||
# Randomly choose from the ground truth (mixed-cased trans) and the recog_text
|
||||
i = random.randint(0, 1)
|
||||
@ -130,6 +151,7 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def get_first(texts: List[str], pre_texts: List[str]) -> str:
|
||||
out = {
|
||||
"text": texts[0],
|
||||
@ -139,6 +161,7 @@ def get_first(texts: List[str], pre_texts: List[str]) -> str:
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
out = {
|
||||
@ -149,6 +172,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def get_adjusted_batch_count(params: AttributeDict) -> float:
|
||||
# returns the number of batches we would have used so far if we had used the reference
|
||||
# duration. This is for purposes of set_batch_count().
|
||||
@ -205,19 +229,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
default="192,256,384,512,384,256",
|
||||
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--memory-dropout-rate",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="By which probability, dropout the memory when doing cross-attention."
|
||||
help="By which probability, dropout the memory when doing cross-attention.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--memory-layer",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Start doing cross-attention from which layer. Zero-indexed"
|
||||
help="Start doing cross-attention from which layer. Zero-indexed",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -226,7 +250,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
default="32",
|
||||
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--value-head-dim",
|
||||
type=str,
|
||||
@ -280,13 +304,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
to this dimension before adding.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -312,29 +335,29 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
"be converted to a number of chunks. If splitting into chunks, "
|
||||
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--text-encoder-type",
|
||||
type=str,
|
||||
default="BERT",
|
||||
choices=["BERT","DistilBERT"],
|
||||
choices=["BERT", "DistilBERT"],
|
||||
help="Type of the text encoder",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--text-encoder-adapter",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="An adapter for pre-trained BERT"
|
||||
help="An adapter for pre-trained BERT",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--context-injection",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Inject context embedding into the joiner",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--context-dropout-rate",
|
||||
type=float,
|
||||
@ -459,8 +482,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -537,14 +559,14 @@ def get_parser():
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use style prompt.",
|
||||
)
|
||||
|
||||
|
||||
# arguments for using prompt
|
||||
parser.add_argument(
|
||||
"--pre-text-shuffle-prob",
|
||||
@ -552,14 +574,14 @@ def get_parser():
|
||||
default=0.05,
|
||||
help="The proportion of pre_text to be shuffled with in a batch",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-shuffle-prob",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="The proportion of style_text to be shuffled with in a batch",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt-mask-prob",
|
||||
type=float,
|
||||
@ -571,14 +593,14 @@ def get_parser():
|
||||
type=str2bool,
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--forced-upper-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Forced format of pre-text",
|
||||
)
|
||||
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -674,25 +696,25 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int=256,
|
||||
embedding_dim: int=256,
|
||||
kernel_size: int=3,
|
||||
num_embeddings: int = 256,
|
||||
embedding_dim: int = 256,
|
||||
kernel_size: int = 3,
|
||||
layer1_channels: int = 256,
|
||||
layer2_channels: int = 256,
|
||||
bias: bool=True,
|
||||
dropout: float = 0.1
|
||||
bias: bool = True,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(
|
||||
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
|
||||
embedding_dim=embedding_dim, #
|
||||
embedding_dim=embedding_dim, #
|
||||
)
|
||||
|
||||
assert embedding_dim == layer1_channels # for depth wise convolution
|
||||
|
||||
assert embedding_dim == layer1_channels # for depth wise convolution
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
embedding_dim,
|
||||
layer1_channels, # depthwise convolution
|
||||
layer1_channels, # depthwise convolution
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
@ -705,7 +727,7 @@ class TextEmbedding(nn.Module):
|
||||
nn.Conv1d(
|
||||
layer1_channels,
|
||||
layer2_channels,
|
||||
kernel_size=1, # pointwise convolution
|
||||
kernel_size=1, # pointwise convolution
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True,
|
||||
@ -713,10 +735,10 @@ class TextEmbedding(nn.Module):
|
||||
Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
|
||||
self.out_norm = BiasNorm(layer2_channels)
|
||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
||||
|
||||
|
||||
def forward(self, text: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function of the text embedding
|
||||
|
||||
@ -725,51 +747,57 @@ class TextEmbedding(nn.Module):
|
||||
Returns:
|
||||
The embeddings of text (T,N,C)
|
||||
"""
|
||||
text = self.embed(text) # (T,N,C)
|
||||
|
||||
#src = text
|
||||
text = text.permute(1,2,0) # (T,N,C) -> (N,C,T)
|
||||
text = self.embed(text) # (T,N,C)
|
||||
|
||||
# src = text
|
||||
text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T)
|
||||
text = self.conv(text)
|
||||
text = text.permute(2,0,1) # (N,C,T) -> (T,N,C)
|
||||
#src = src + text
|
||||
|
||||
text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C)
|
||||
# src = src + text
|
||||
|
||||
text = self.out_norm(text)
|
||||
text = self.dropout(text)
|
||||
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
def get_text_encoder(params: AttributeDict) -> nn.Module:
|
||||
# Return a text encoder
|
||||
if params.text_encoder_type == "BERT":
|
||||
from transformers import BertModel
|
||||
|
||||
# This is a BERT-base-cased
|
||||
logging.info("Loading pre-trained BERT-base-cased as text encoder")
|
||||
model = BertModel.from_pretrained("bert-base-cased")
|
||||
elif params.text_encoder_type == "DistilBERT":
|
||||
from transformers import DistilBertModel
|
||||
|
||||
# This is a DistilBERT-base-cased
|
||||
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
|
||||
model = DistilBertModel.from_pretrained("distilbert-base-cased")
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_tokenizer(params: AttributeDict):
|
||||
|
||||
|
||||
if params.text_encoder_type == "BERT":
|
||||
from transformers import BertTokenizer
|
||||
|
||||
# This is a BERT-base-cased
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
elif params.text_encoder_type == "DistilBERT":
|
||||
from transformers import DistilBertTokenizer
|
||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
||||
|
||||
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = Zipformer2(
|
||||
output_downsampling_factor=2,
|
||||
@ -789,7 +817,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
causal=params.causal,
|
||||
chunk_size=_to_int_tuple(params.chunk_size),
|
||||
left_context_frames=_to_int_tuple(params.left_context_frames),
|
||||
memory_dim=768, # This is fixed as the BERT base model is 768-D
|
||||
memory_dim=768, # This is fixed as the BERT base model is 768-D
|
||||
memory_layer=params.memory_layer,
|
||||
memory_dropout_rate=params.memory_dropout_rate,
|
||||
)
|
||||
@ -812,7 +840,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
context_dim=4 * 768 if params.context_injection else -1, # the output dim of text encoder
|
||||
context_dim=4 * 768
|
||||
if params.context_injection
|
||||
else -1, # the output dim of text encoder
|
||||
context_injection=params.context_injection,
|
||||
)
|
||||
return joiner
|
||||
@ -821,23 +851,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder_embed = get_encoder_embed(params)
|
||||
encoder = get_encoder_model(params)
|
||||
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
|
||||
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
|
||||
num_param = sum([p.numel() for p in text_encoder.parameters()])
|
||||
logging.info(f"Num params in text encoder: {num_param}")
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
if params.context_injection:
|
||||
from context_fuser import ContextFuser, SelfAttContextFuser
|
||||
context_fuser = SelfAttContextFuser(
|
||||
embed_dim=768,
|
||||
nhead=4,
|
||||
context_dropout_rate=params.context_dropout_rate,
|
||||
)
|
||||
logging.info(f"Using context injection!")
|
||||
logging.info(context_fuser)
|
||||
else:
|
||||
context_fuser = None
|
||||
|
||||
model = PromptedTransducer(
|
||||
encoder_embed=encoder_embed,
|
||||
@ -851,12 +869,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
vocab_size=params.vocab_size,
|
||||
text_encoder_type=params.text_encoder_type,
|
||||
text_encoder_adapter=params.text_encoder_adapter,
|
||||
context_fuser=context_fuser,
|
||||
context_fuser=None,
|
||||
)
|
||||
|
||||
if params.text_encoder_adapter:
|
||||
logging.info(f"Using adapter for BERT encoder")
|
||||
logging.info(f"{model.text_encoder_adapter}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@ -978,13 +993,14 @@ def save_checkpoint(
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts: List[str],
|
||||
pre_texts: List[str],
|
||||
style_texts: List[str],
|
||||
tokenizer,
|
||||
device: torch.device,
|
||||
max_len: int=500,
|
||||
no_limit: bool=False
|
||||
max_len: int = 500,
|
||||
no_limit: bool = False,
|
||||
) -> Tuple[Dict, Tensor]:
|
||||
"""
|
||||
Encode texts as bytes and then integer tensors.
|
||||
@ -992,36 +1008,39 @@ def _encode_texts_as_bytes_with_tokenizer(
|
||||
"""
|
||||
batch_size = len(pre_texts)
|
||||
max_len = min(max_len, 500)
|
||||
|
||||
|
||||
if no_limit:
|
||||
allowed_lens = [5000 - len(s) for s in style_texts]
|
||||
else:
|
||||
allowed_lens = [1000 - len(s) for s in style_texts]
|
||||
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
|
||||
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
|
||||
|
||||
truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)]
|
||||
combined_text = [
|
||||
style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size)
|
||||
]
|
||||
|
||||
encoded_style_texts = tokenizer(
|
||||
style_texts,
|
||||
return_tensors='pt',
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
max_length=max_len,
|
||||
)
|
||||
style_lens = encoded_style_texts["length"].to(device)
|
||||
|
||||
|
||||
# Use tokenizer to prepare input for text encoder
|
||||
encoded_inputs = tokenizer(
|
||||
combined_text,
|
||||
return_tensors='pt',
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
max_length=max_len,
|
||||
).to(device)
|
||||
|
||||
|
||||
return encoded_inputs, style_lens
|
||||
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
@ -1048,11 +1067,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -1067,20 +1082,24 @@ def compute_loss(
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
pre_texts = batch["supervisions"]["pre_text"]
|
||||
style_texts = batch["supervisions"]["style_text"] # the style texts are in gt format
|
||||
style_texts = batch["supervisions"][
|
||||
"style_text"
|
||||
] # the style texts are in gt format
|
||||
transform_ids = batch["supervisions"]["transform_ids"]
|
||||
|
||||
|
||||
# This is to replace full-width symbols with half-width symbols
|
||||
texts = [train_text_normalization(t) for t in texts]
|
||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||
|
||||
y = sp.encode(texts, out_type=int) # sp.encode treats consecutive space as a single space
|
||||
|
||||
y = sp.encode(
|
||||
texts, out_type=int
|
||||
) # sp.encode treats consecutive space as a single space
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
|
||||
if params.forced_upper_pre_text:
|
||||
pre_texts = [upper_only_alpha(p) for p in pre_texts]
|
||||
|
||||
|
||||
# only shuffle the pre_text and style texts if during training, and use style prompt
|
||||
if is_training:
|
||||
# randomly shuffle&mask the pre_text
|
||||
@ -1089,38 +1108,40 @@ def compute_loss(
|
||||
p=params.pre_text_shuffle_prob,
|
||||
p_mask=params.prompt_mask_prob,
|
||||
)
|
||||
|
||||
|
||||
if params.use_style_prompt:
|
||||
if random.random() < 0.5:
|
||||
if random.random() < 0.5:
|
||||
# randomly shuffle the style_text
|
||||
# now the style_texts are all in gt format
|
||||
style_texts = random_shuffle_subset(
|
||||
style_texts,
|
||||
p=params.style_text_shuffle_prob,
|
||||
p_mask=params.prompt_mask_prob
|
||||
)
|
||||
|
||||
p_mask=params.prompt_mask_prob,
|
||||
)
|
||||
|
||||
assert len(transform_ids) == len(style_texts)
|
||||
|
||||
|
||||
for i in range(len(style_texts)):
|
||||
t = transform_ids[i] # get the transform id
|
||||
t = transform_ids[i] # get the transform id
|
||||
style_texts[i] = style_transforms[t](style_texts[i])
|
||||
|
||||
if not params.use_style_prompt:
|
||||
style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt
|
||||
|
||||
style_texts = [
|
||||
"" for _ in style_texts
|
||||
] # use empty string for style texts if don't use style prompt
|
||||
|
||||
if random.random() < 0.05:
|
||||
logging.info(f"Pre texts: {pre_texts[0]}")
|
||||
logging.info(f"Ref texts: {texts[0]}")
|
||||
logging.info(f"Style texts: {style_texts[0]}")
|
||||
|
||||
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
if random.random() < 0.02:
|
||||
logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
|
||||
|
||||
@ -1157,9 +1178,7 @@ def compute_loss(
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -1352,9 +1371,7 @@ def train_one_epoch(
|
||||
# behavior depending on the current grad scale.
|
||||
cur_grad_scale = scaler._scale.item()
|
||||
|
||||
if cur_grad_scale < 8.0 or (
|
||||
cur_grad_scale < 32.0 and batch_idx % 400 == 0
|
||||
):
|
||||
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
if cur_grad_scale < 0.01:
|
||||
if not saved_bad_model:
|
||||
@ -1376,11 +1393,7 @@ def train_one_epoch(
|
||||
f"batch {batch_idx}, loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||
f"lr: {cur_lr:.2e}, "
|
||||
+ (
|
||||
f"grad_scale: {scaler._scale.item()}"
|
||||
if params.use_fp16
|
||||
else ""
|
||||
)
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
@ -1391,9 +1404,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
if params.use_fp16:
|
||||
tb_writer.add_scalar(
|
||||
"train/grad_scale",
|
||||
@ -1401,10 +1412,7 @@ def train_one_epoch(
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
if (
|
||||
batch_idx % params.valid_interval == 0
|
||||
and not params.print_diagnostics
|
||||
):
|
||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
@ -1452,11 +1460,15 @@ def run(rank, world_size, args):
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
|
||||
|
||||
if not params.use_style_prompt:
|
||||
if params.pre_text_shuffle_prob == 0.0:
|
||||
logging.info(f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}")
|
||||
logging.info("If style prompt is not used, you should be careful when shuffling the pre_text within the same batch")
|
||||
if params.pre_text_shuffle_prob == 0.0:
|
||||
logging.info(
|
||||
f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}"
|
||||
)
|
||||
logging.info(
|
||||
"If style prompt is not used, you should be careful when shuffling the pre_text within the same batch"
|
||||
)
|
||||
logging.info("Hard set this probability to 0.0!")
|
||||
params.pre_text_shuffle_prob = 0.0
|
||||
|
||||
@ -1504,10 +1516,12 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.freeze_text_encoder:
|
||||
freeze_modules = ["text_encoder"]
|
||||
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
|
||||
logging.info(
|
||||
"Freeze the parameters of text encoder and don't include them in the optimizer"
|
||||
)
|
||||
else:
|
||||
freeze_modules = []
|
||||
|
||||
|
||||
optimizer = ScaledAdam(
|
||||
get_parameter_groups_with_lrs(
|
||||
model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
|
||||
@ -1533,7 +1547,7 @@ def run(rank, world_size, args):
|
||||
if params.print_diagnostics:
|
||||
args.max_duration = 100
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
@ -1543,7 +1557,7 @@ def run(rank, world_size, args):
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
train_cuts = libriheavy.train_cuts()
|
||||
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
#
|
||||
@ -1586,10 +1600,14 @@ def run(rank, world_size, args):
|
||||
sampler_state_dict = checkpoints["sampler"]
|
||||
else:
|
||||
sampler_state_dict = None
|
||||
|
||||
text_sampling_func = triplet_text_sampling
|
||||
|
||||
if params.use_context_list:
|
||||
text_sampling_func = triplet_text_sampling_with_context_list
|
||||
else:
|
||||
text_sampling_func = triplet_text_sampling
|
||||
|
||||
logging.info(f"Text sampling: {text_sampling_func}")
|
||||
|
||||
|
||||
train_dl = libriheavy.train_dataloaders(
|
||||
train_cuts,
|
||||
sampler_state_dict=sampler_state_dict,
|
||||
@ -1599,18 +1617,17 @@ def run(rank, world_size, args):
|
||||
# For fair comparison, use fixed sampling in valid dataloaders
|
||||
valid_cuts = libriheavy.dev_cuts()
|
||||
valid_dl = libriheavy.valid_dataloaders(
|
||||
valid_cuts,
|
||||
text_sampling_func=naive_triplet_text_sampling
|
||||
valid_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
|
||||
# if not params.print_diagnostics:
|
||||
# scan_pessimistic_batches_for_oom(
|
||||
# model=model,
|
||||
# train_dl=train_dl,
|
||||
# optimizer=optimizer,
|
||||
# sp=sp,
|
||||
# params=params,
|
||||
# )
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
|
Loading…
x
Reference in New Issue
Block a user