diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
index 5234cebf3..e253d1118 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
@@ -22,7 +22,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # For mix precision training:
 
-(1) Non-streaming model, without context list
+(1) Non-streaming model, **without** context list
 
 ./zipformer_prompt_asr/train.py \
     --world-size 4 \
@@ -34,12 +34,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
     --exp-dir zipformer_prompt_asr/exp \
     --max-duration 1000 \
     --memory-layer 0 \
-    --memory-dim 768 \
     --text-encoder-type BERT \
+    --text-encoder-dim 768 \
     --use-style-prompt True \
     --use-context-list False
 
-(2) Non-streaming model, with context list
+(2) Non-streaming model, **with** context list
 
 ./zipformer_prompt_asr/train.py \
     --world-size 4 \
@@ -51,10 +51,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
     --exp-dir zipformer_prompt_asr/exp \
     --max-duration 1000 \
     --memory-layer 0 \
-    --memory-dim 768 \
     --text-encoder-type BERT \
+    --text-encoder-dim 768 \
     --use-style-prompt True \
     --use-context-list True \
+    --top-k 10000 \
     --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
 
 
@@ -64,6 +65,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import copy
 import logging
+import os
 import random
 import warnings
 from pathlib import Path
@@ -134,22 +136,6 @@ style_transforms = [
 ]
 
 
-def random_sampling(texts: List[str]) -> str:
-    return random.choice(texts)
-
-
-def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
-    # Randomly choose from the ground truth (mixed-cased trans) and the recog_text
-    i = random.randint(0, 1)
-    out = {
-        "text": texts[i],
-        "pre_text": pre_texts[i],
-        "style_text": "",
-        "transform_ids": 0,
-    }
-    return out
-
-
 def get_first(texts: List[str], pre_texts: List[str]) -> str:
     out = {
         "text": texts[0],
@@ -313,7 +299,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--causal",
         type=str2bool,
-        default=True,
+        default=False,
         help="If True, use causal version of model.",
     )
 
@@ -348,6 +334,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Type of the text encoder",
     )
 
+    parser.add_argument(
+        "--text-encoder-dim",
+        type=int,
+        default=768,
+        help="Dimension of the text encoder",
+    )
+
     parser.add_argument(
         "--text-encoder-adapter",
         type=str2bool,
@@ -762,18 +755,30 @@ class TextEmbedding(nn.Module):
 
 def get_text_encoder(params: AttributeDict) -> nn.Module:
     # Return a text encoder
-    if params.text_encoder_type == "BERT":
+    if params.text_encoder_type == "BERT":  # This is a BERT-base-cased
         from transformers import BertModel
 
-        # This is a BERT-base-cased
         logging.info("Loading pre-trained BERT-base-cased as text encoder")
-        model = BertModel.from_pretrained("bert-base-cased")
-    elif params.text_encoder_type == "DistilBERT":
-        from transformers import DistilBertModel
+        if os.path.exists("data/models/bert-base-cased"):
+            model = BertModel.from_pretrained("data/models/bert-base-cased")
+        else:
+            model = BertModel.from_pretrained("bert-base-cased")
+        assert params.text_encoder_dim == 768
+    elif params.text_encoder_type == "BERT-large":
+        from transformers import BertModel
+
+        logging.info("Loading pre-trained BERT-large-uncased as text encoder")
+        if os.path.exists("data/models/bert-large-uncased"):
+            model = BertModel.from_pretrained("data/models/bert-large-uncased")
+        else:
+            model = BertModel.from_pretrained("bert-large-uncased")
+        assert params.text_encoder_dim == 1024
+    elif params.text_encoder_type == "DistilBERT":
+        from transformers import DistilBertModel  # This is a DistilBERT-base-cased
 
-        # This is a DistilBERT-base-cased
         logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
+        assert params.text_encoder_dim == 768
     else:
         raise ValueError()
 
@@ -786,7 +791,18 @@ def get_tokenizer(params: AttributeDict):
         from transformers import BertTokenizer
 
         # This is a BERT-base-cased
-        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+        if os.path.exists("data/models/bert-base-cased"):
+            tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased")
+        else:
+            tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+    elif params.text_encoder_type == "BERT-large":
+        from transformers import BertTokenizer
+
+        # This is a BERT-large-uncased
+        if os.path.exists("data/models/bert-large-uncased"):
+            tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased")
+        else:
+            tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertTokenizer
 
@@ -816,7 +832,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
-        memory_dim=768,  # This is fixed as the BERT base model is 768-D
+        memory_dim=params.text_encoder_dim,  # Must match the output dimension of the text encoder
         memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
@@ -848,13 +864,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder_embed = get_encoder_embed(params)
-    encoder = get_encoder_model(params)
-
     text_encoder = get_text_encoder(params)  # This should be a cased BERT base model
     num_param = sum([p.numel() for p in text_encoder.parameters()])
     logging.info(f"Num params in text encoder: {num_param}")
 
+    encoder_embed = get_encoder_embed(params)
+    encoder = get_encoder_model(params)
+
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
@@ -1622,15 +1638,15 @@ def run(rank, world_size, args):
         valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )
 
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         tokenizer=tokenizer,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            tokenizer=tokenizer,
+            params=params,
+        )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
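
Note on the changes above: the new --text-encoder-dim flag must match the hidden size of the chosen text encoder (768 for BERT-base-cased and DistilBERT-base-cased, 1024 for BERT-large-uncased), which the added assertions enforce. The patch also prefers a local copy of the pre-trained weights under data/models/ over downloading from the Hugging Face hub. The sketch below is an illustration, not part of the patch: a hypothetical one-off helper (run once with network access) that prepares such local copies so the offline branch is taken.

    # prepare_local_bert.py -- hypothetical helper, not part of this patch.
    # Downloads the text encoders and tokenizers referenced by
    # get_text_encoder() / get_tokenizer() and saves them under data/models/
    # so that the os.path.exists(...) checks above succeed during training.
    from transformers import BertModel, BertTokenizer

    for name in ("bert-base-cased", "bert-large-uncased"):
        local_dir = f"data/models/{name}"
        BertModel.from_pretrained(name).save_pretrained(local_dir)
        BertTokenizer.from_pretrained(name).save_pretrained(local_dir)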