update training script

marcoyang 2023-10-10 16:42:22 +08:00
parent 9f48d06581
commit 7a9c18fc79


@@ -22,7 +22,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For mixed precision training:
(1) Non-streaming model, without context list
(1) Non-streaming model, **without** context list
./zipformer_prompt_asr/train.py \
--world-size 4 \
@@ -34,12 +34,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir zipformer_prompt_asr/exp \
--max-duration 1000 \
--memory-layer 0 \
--memory-dim 768 \
--text-encoder-type BERT \
--text-encoder-dim 768 \
--use-style-prompt True \
--use-context-list False
(2) Non-streaming model, with context list
(2) Non-streaming model, **with** context list
./zipformer_prompt_asr/train.py \
--world-size 4 \
@@ -51,10 +51,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir zipformer_prompt_asr/exp \
--max-duration 1000 \
--memory-layer 0 \
--memory-dim 768 \
--text-encoder-type BERT \
--text-encoder-dim 768 \
--use-style-prompt True \
--use-context-list True \
--top-k 10000 \
--rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
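
The --rare-word-file flag points at a plain text list of rare words used for context biasing. The generation recipe is not part of this diff; a hypothetical sketch, assuming one word per line and a frequency-based top-k selection:

from collections import Counter
from typing import Iterable

def write_rare_words(transcripts: Iterable[str], out_path: str, top_k: int = 10000) -> None:
    # Count word occurrences across the training transcripts and keep the
    # top_k least frequent words; "rare" here is an assumed definition.
    counts = Counter(w for t in transcripts for w in t.split())
    rare = [w for w, _ in counts.most_common()][-top_k:]
    with open(out_path, "w") as f:
        f.write("\n".join(rare) + "\n")
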
@@ -64,6 +65,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
@@ -134,22 +136,6 @@ style_transforms = [
]
def random_sampling(texts: List[str]) -> str:
return random.choice(texts)
def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
# Randomly choose between the ground truth (mixed-cased transcript) and the recog_text
i = random.randint(0, 1)
out = {
"text": texts[i],
"pre_text": pre_texts[i],
"style_text": "",
"transform_ids": 0,
}
return out
def get_first(texts: List[str], pre_texts: List[str]) -> str:
out = {
"text": texts[0],
@@ -313,7 +299,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--causal",
type=str2bool,
default=True,
default=False,
help="If True, use causal version of model.",
)
@@ -348,6 +334,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Type of the text encoder",
)
parser.add_argument(
"--text-encoder-dim",
type=int,
default=768,
help="Dimension of the text encoder",
)
parser.add_argument(
"--text-encoder-adapter",
type=str2bool,
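
The new --text-encoder-dim argument has to agree with the hidden size of the selected text encoder; the asserts in get_text_encoder (further down in this diff) enforce the pairing. As a quick reference, summarized from those asserts:

# Hidden sizes enforced by the asserts in get_text_encoder():
TEXT_ENCODER_DIMS = {
    "BERT": 768,         # bert-base-cased
    "BERT-large": 1024,  # bert-large-uncased
    "DistilBERT": 768,   # distilbert-base-cased
}
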
@@ -762,18 +755,30 @@ class TextEmbedding(nn.Module):
def get_text_encoder(params: AttributeDict) -> nn.Module:
# Return a text encoder
if params.text_encoder_type == "BERT":
if params.text_encoder_type == "BERT": # This is a BERT-base-cased
from transformers import BertModel
# This is a BERT-base-cased
logging.info("Loading pre-trained BERT-base-cased as text encoder")
model = BertModel.from_pretrained("bert-base-cased")
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertModel
if os.path.exists("data/models/bert-base-cased"):
model = BertModel.from_pretrained("data/models/bert-base-cased")
else:
model = BertModel.from_pretrained("bert-base-cased")
assert params.text_encoder_dim == 768
elif params.text_encoder_type == "BERT-large":
from transformers import BertModel
logging.info("Loading pre-trained BERT-large-uncased as text encoder")
if os.path.exists("data/models/bert-large-uncased"):
model = BertModel.from_pretrained("data/models/bert-large-uncased")
else:
model = BertModel.from_pretrained("bert-large-uncased")
assert params.text_encoder_dim == 1024
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertModel # This is a DistilBERT-base-cased
# This is a DistilBERT-base-cased
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
model = DistilBertModel.from_pretrained("distilbert-base-cased")
assert params.text_encoder_dim == 768
else:
raise ValueError(f"Unknown text encoder type: {params.text_encoder_type}")
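
Each branch above repeats the same local-first loading pattern: prefer a snapshot under data/models/ and only fall back to downloading from the Hugging Face Hub, which matters on machines without internet access. A minimal sketch of that pattern as a helper (the helper itself is illustrative, not part of the script):

import os
from transformers import BertModel

def load_pretrained(model_cls, name, local_root="data/models"):
    # Prefer a local snapshot (e.g. pre-downloaded for an offline cluster);
    # otherwise fall back to fetching from the Hugging Face Hub by name.
    local_path = os.path.join(local_root, name)
    source = local_path if os.path.exists(local_path) else name
    return model_cls.from_pretrained(source)

# e.g. model = load_pretrained(BertModel, "bert-base-cased")
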
@@ -786,7 +791,18 @@ def get_tokenizer(params: AttributeDict):
from transformers import BertTokenizer
# This is a BERT-base-cased
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
if os.path.exists("data/models/bert-base-cased"):
tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased")
else:
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
elif params.text_encoder_type == "BERT-large":
from transformers import BertTokenizer
# This is a BERT-large-uncased
if os.path.exists("data/models/bert-large-uncased"):
tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased")
else:
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertTokenizer
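
The tokenizer and the text encoder are used together to turn a text prompt into the memory embeddings that the Zipformer cross-attends to. A rough usage sketch, assuming BERT-base-cased (so the last dimension matches text_encoder_dim=768):

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
text_encoder = BertModel.from_pretrained("bert-base-cased")

batch = tokenizer(["prompt with context words"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = text_encoder(**batch)
memory = out.last_hidden_state  # (batch, seq_len, 768) for BERT-base
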
@@ -816,7 +832,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
memory_dim=768, # This is fixed as the BERT base model is 768-D
memory_dim=params.text_encoder_dim, # Must match the text encoder output: 768 for BERT-base, 1024 for BERT-large
memory_layer=params.memory_layer,
memory_dropout_rate=params.memory_dropout_rate,
)
@@ -848,13 +864,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
text_encoder = get_text_encoder(params) # e.g. BERT-base-cased or BERT-large-uncased
num_param = sum([p.numel() for p in text_encoder.parameters()])
logging.info(f"Num params in text encoder: {num_param}")
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
@@ -1622,15 +1638,15 @@ def run(rank, world_size, args):
valid_cuts, text_sampling_func=naive_triplet_text_sampling
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# tokenizer=tokenizer,
# params=params,
# )
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
tokenizer=tokenizer,
params=params,
)
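
Re-enabling scan_pessimistic_batches_for_oom makes training fail fast: the most memory-hungry batches the sampler can emit are pushed through a forward/backward pass before epoch 1, so an out-of-memory error surfaces immediately instead of hours into training. An illustrative sketch of the idea (not the exact icefall implementation; compute_loss is a hypothetical helper):

import logging

def oom_sanity_check(model, train_dl, compute_loss):
    # Per criterion (e.g. max duration), find the worst-case batch the
    # sampler can produce and run it through forward + backward once.
    from lhotse.dataset import find_pessimistic_batches

    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        logging.info(f"Probing batch with extreme {criterion}={crit_values[criterion]}")
        batch = train_dl.dataset[cuts]     # collate the CutSet into a batch
        loss = compute_loss(model, batch)  # hypothetical loss helper
        loss.backward()                    # exercise backward-pass memory too
        model.zero_grad(set_to_none=True)
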
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: