diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder_with_style.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder_with_style.py
index 09b11c61e..de1b6ab85 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder_with_style.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder_with_style.py
@@ -61,14 +61,15 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import (
+from dataset2 import (
     triplet_text_sampling,
-    multi_ref_text_triplet_text_sampling,
+    triplet_text_sampling_with_context_list,
     naive_triplet_text_sampling,
     random_shuffle_subset,
     joint_triplet_text_sampling,
     triplet_style_text_sampling,
 )
+from dataset import multi_ref_text_triplet_text_sampling
 from decoder import Decoder
 from joiner import Joiner
 
@@ -211,6 +212,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=0.05,
         help="By which probability, dropout the memory when doing cross-attention."
     )
+
+    parser.add_argument(
+        "--memory-layer",
+        type=int,
+        default=0,
+        help="Start doing cross-attention from which layer. Zero-indexed"
+    )
 
     parser.add_argument(
         "--query-head-dim",
@@ -558,6 +566,11 @@
         default=0.05,
         help="The probability of masking prompts",
     )
+    parser.add_argument(
+        "--freeze-text-encoder",
+        type=str2bool,
+        default=True,
+    )
 
     parser.add_argument(
         "--forced-upper-pre-text",
@@ -731,12 +744,12 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
     if params.text_encoder_type == "BERT":
         from transformers import BertModel
         # This is a BERT-base-cased
-        logging.info("Loading pre-trained BERT-base-cased as text encaoder")
+        logging.info("Loading pre-trained BERT-base-cased as text encoder")
         model = BertModel.from_pretrained("bert-base-cased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertModel
         # This is a DistilBERT-base-cased
-        logging.info("Loading pre-trained DistilBERT-base-cased as text encaoder")
+        logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
     else:
         raise ValueError()
@@ -777,6 +790,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
         memory_dim=768,  # This is fixed as the BERT base model is 768-D
+        memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
     return encoder
@@ -970,14 +984,19 @@ def _encode_texts_as_bytes_with_tokenizer(
     tokenizer,
     device: torch.device,
     max_len: int=500,
+    no_limit: bool=False
 ) -> Tuple[Dict, Tensor]:
     """
     Encode texts as bytes and then integer tensors.
     Note that the style text will be added to the beginning of texts.
     """
     batch_size = len(pre_texts)
+    max_len = min(max_len, 500)
 
-    allowed_lens = [1000 - len(s) for s in style_texts]
+    if no_limit:
+        allowed_lens = [5000 - len(s) for s in style_texts]
+    else:
+        allowed_lens = [1000 - len(s) for s in style_texts]
     truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
     combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
 
@@ -987,7 +1006,7 @@ def _encode_texts_as_bytes_with_tokenizer(
         padding=True,
         truncation=True,
         return_length=True,
-        max_length=500,
+        max_length=max_len,
     )
 
     style_lens = encoded_style_texts["length"].to(device)
@@ -998,7 +1017,7 @@ def _encode_texts_as_bytes_with_tokenizer(
         padding=True,
         truncation=True,
         return_length=True,
-        max_length=500,
+        max_length=max_len,
     ).to(device)
 
     return encoded_inputs, style_lens
@@ -1090,8 +1109,8 @@ def compute_loss(
     if not params.use_style_prompt:
         style_texts = ["" for _ in style_texts]  # use empty string for style texts if don't use style prompt
 
-    if random.random() < 0.03:
-        logging.info(f"Pre_texts: {pre_texts[0]}")
+    if random.random() < 0.05:
+        logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")
 
@@ -1483,9 +1502,15 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
+    if params.freeze_text_encoder:
+        freeze_modules = ["text_encoder"]
+        logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
+    else:
+        freeze_modules = []
+
     optimizer = ScaledAdam(
         get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True
+            model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
         ),
         lr=params.base_lr,  # should have no effect
         clipping_scale=2.0,
@@ -1506,6 +1531,7 @@
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     if params.print_diagnostics:
+        args.max_duration = 100
         opts = diagnostics.TensorDiagnosticOptions(
             2 ** 22
         )  # allow 4 megabytes per sub-module
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_subformer_with_style.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_subformer_with_style.py
index abb7872dd..4660885c2 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_subformer_with_style.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_subformer_with_style.py
@@ -61,8 +61,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
-from dataset2 import triplet_text_sampling_with_context_list
+from dataset import naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
+from dataset2 import triplet_text_sampling, triplet_text_sampling_with_context_list
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1093,9 +1093,8 @@ def _encode_text_as_tokens(
     style_texts: List[str],
     bpe_model: spm.SentencePieceProcessor,
     device: torch.device,
-    max_tokens: int=500,
+    max_tokens: int=800,
 ) -> Tuple[Tensor, Tensor]:
-    max_tokens = min(500, max_tokens)
     batch_size = len(pre_texts)
 
     # encoded style texts
@@ -1216,6 +1215,7 @@ def compute_loss(
     feature = feature.to(device)
 
     supervisions = batch["supervisions"]
+    cut_ids = [c.id for c in supervisions["cut"]]
     feature_lens = supervisions["num_frames"].to(device)
 
     batch_idx_train = params.batch_idx_train
@@ -1266,6 +1266,9 @@ def compute_loss(
         logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")
+        orig_style_texts = batch["supervisions"]["style_text"]
+        logging.info(f"Orig style texts: {orig_style_texts[0]}")
+        logging.info(f"Cut ID: {cut_ids[0]}")
 
     pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
         pre_texts=pre_texts,
@@ -1752,7 +1755,7 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    text_sampling_func = triplet_text_sampling_with_context_list
+    text_sampling_func = triplet_text_sampling
     logging.info(f"Text sampling: {text_sampling_func}")
 
     train_dl = libriheavy.train_dataloaders(