mirror of https://github.com/k2-fsa/icefall.git
update training script
commit 7a9c18fc79
parent 9f48d06581
@@ -22,7 +22,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 # For mix precision training:
 
-(1) Non-streaming model, without context list
+(1) Non-streaming model, **without** context list
 
 ./zipformer_prompt_asr/train.py \
   --world-size 4 \
@@ -34,12 +34,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir zipformer_prompt_asr/exp \
   --max-duration 1000 \
   --memory-layer 0 \
   --memory-dim 768 \
   --text-encoder-type BERT \
   --text-encoder-dim 768 \
   --use-style-prompt True \
   --use-context-list False
 
-(2) Non-streaming model, with context list
+(2) Non-streaming model, **with** context list
 
 ./zipformer_prompt_asr/train.py \
   --world-size 4 \
@@ -51,10 +51,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir zipformer_prompt_asr/exp \
   --max-duration 1000 \
   --memory-layer 0 \
   --memory-dim 768 \
   --text-encoder-type BERT \
   --text-encoder-dim 768 \
   --use-style-prompt True \
   --use-context-list True \
   --top-k 10000 \
   --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
@@ -64,6 +65,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import copy
 import logging
+import os
 import random
 import warnings
 from pathlib import Path
@@ -134,22 +136,6 @@ style_transforms = [
 ]
 
 
-def random_sampling(texts: List[str]) -> str:
-    return random.choice(texts)
-
-
-def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
-    # Randomly choose from the ground truth (mixed-cased trans) and the recog_text
-    i = random.randint(0, 1)
-    out = {
-        "text": texts[i],
-        "pre_text": pre_texts[i],
-        "style_text": "",
-        "transform_ids": 0,
-    }
-    return out
-
-
 def get_first(texts: List[str], pre_texts: List[str]) -> str:
     out = {
         "text": texts[0],
@@ -313,7 +299,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--causal",
         type=str2bool,
-        default=True,
+        default=False,
         help="If True, use causal version of model.",
     )
 
@@ -348,6 +334,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Type of the text encoder",
     )
 
+    parser.add_argument(
+        "--text-encoder-dim",
+        type=int,
+        default=768,
+        help="Dimension of the text encoder",
+    )
+
     parser.add_argument(
         "--text-encoder-adapter",
         type=str2bool,
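
For illustration, a minimal sketch (not code from this commit) of how the new --text-encoder-dim flag pairs with --text-encoder-type. The standalone parser, the helper name check_text_encoder_args, and the EXPECTED_DIMS table are assumptions for this sketch only; the dimension values mirror the assertions added to get_text_encoder() below (768 for BERT/DistilBERT, 1024 for BERT-large).

    # Illustrative sketch only (not from train.py): validate the new
    # --text-encoder-dim flag against --text-encoder-type before training.
    import argparse

    # Hypothetical table for this sketch; the real checks are the asserts in
    # get_text_encoder(): 768 for BERT/DistilBERT, 1024 for BERT-large.
    EXPECTED_DIMS = {"BERT": 768, "BERT-large": 1024, "DistilBERT": 768}


    def check_text_encoder_args(args: argparse.Namespace) -> None:
        expected = EXPECTED_DIMS.get(args.text_encoder_type)
        if expected is not None and args.text_encoder_dim != expected:
            raise ValueError(
                f"--text-encoder-dim should be {expected} for "
                f"{args.text_encoder_type}, got {args.text_encoder_dim}"
            )


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--text-encoder-type", type=str, default="BERT")
        parser.add_argument("--text-encoder-dim", type=int, default=768)
        check_text_encoder_args(parser.parse_args())
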
@@ -762,18 +755,30 @@ class TextEmbedding(nn.Module):
 
 def get_text_encoder(params: AttributeDict) -> nn.Module:
     # Return a text encoder
-    if params.text_encoder_type == "BERT":
+    if params.text_encoder_type == "BERT":  # This is a BERT-base-cased
         from transformers import BertModel
 
-        # This is a BERT-base-cased
         logging.info("Loading pre-trained BERT-base-cased as text encoder")
-        model = BertModel.from_pretrained("bert-base-cased")
-    elif params.text_encoder_type == "DistilBERT":
-        from transformers import DistilBertModel
+        if os.path.exists("data/models/bert-base-cased"):
+            model = BertModel.from_pretrained("data/models/bert-base-cased")
+        else:
+            model = BertModel.from_pretrained("bert-base-cased")
+        assert params.text_encoder_dim == 768
+    elif params.text_encoder_type == "BERT-large":
+        from transformers import BertModel
+
+        logging.info("Loading pre-trained BERT-large-uncased as text encoder")
+        if os.path.exists("data/models/bert-large-uncased"):
+            model = BertModel.from_pretrained("data/models/bert-large-uncased")
+        else:
+            model = BertModel.from_pretrained("bert-large-uncased")
+        assert params.text_encoder_dim == 1024
+    elif params.text_encoder_type == "DistilBERT":
+        from transformers import DistilBertModel  # This is a DistilBERT-base-cased
 
-        # This is a DistilBERT-base-cased
         logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
+        assert params.text_encoder_dim == 768
     else:
         raise ValueError()
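
The hunk above switches to a local-first loading pattern: if a copy of the model exists under data/models/, it is loaded from disk, otherwise from_pretrained() falls back to the Hugging Face Hub name. A self-contained sketch of the same pattern follows; the helper name load_bert_encoder and its default paths are illustrative, not names used in train.py.

    # Sketch of the local-first loading pattern used above (illustrative names).
    import logging
    import os

    from transformers import BertModel


    def load_bert_encoder(
        hub_name: str = "bert-base-cased",
        local_dir: str = "data/models/bert-base-cased",
    ) -> BertModel:
        # from_pretrained() accepts either a local directory or a Hub model id,
        # so a pre-downloaded copy avoids any network access at training time.
        if os.path.exists(local_dir):
            logging.info(f"Loading text encoder from local directory {local_dir}")
            return BertModel.from_pretrained(local_dir)
        logging.info(f"Loading text encoder from the Hugging Face Hub: {hub_name}")
        return BertModel.from_pretrained(hub_name)
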
@@ -786,7 +791,18 @@ def get_tokenizer(params: AttributeDict):
         from transformers import BertTokenizer
 
         # This is a BERT-base-cased
-        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+        if os.path.exists("data/models/bert-base-cased"):
+            tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased")
+        else:
+            tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+    elif params.text_encoder_type == "BERT-large":
+        from transformers import BertTokenizer
+
+        # This is a BERT-large-uncased
+        if os.path.exists("data/models/bert-large-uncased"):
+            tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased")
+        else:
+            tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertTokenizer
 
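
A rough usage sketch (not repository code) of the tokenizer/encoder pair loaded above: the prompt's hidden states have size 768 for bert-base-cased, which is the value that --text-encoder-dim and, via the change below, the zipformer's memory_dim must match.

    # Rough usage sketch (not from train.py): encode a text prompt and check
    # that its hidden size matches what the zipformer expects as memory_dim.
    import torch
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    model = BertModel.from_pretrained("bert-base-cased")

    inputs = tokenizer("some mixed-cased context words", return_tensors="pt")
    with torch.no_grad():
        memory = model(**inputs).last_hidden_state  # shape: (1, num_tokens, 768)

    assert memory.shape[-1] == 768  # should equal params.text_encoder_dim / memory_dim
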
@@ -816,7 +832,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
-        memory_dim=768,  # This is fixed as the BERT base model is 768-D
+        memory_dim=params.text_encoder_dim,  # This is fixed as the BERT base model is 768-D
         memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
@@ -848,13 +864,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder_embed = get_encoder_embed(params)
-    encoder = get_encoder_model(params)
-
     text_encoder = get_text_encoder(params)  # This should be a cased BERT base model
     num_param = sum([p.numel() for p in text_encoder.parameters()])
     logging.info(f"Num params in text encoder: {num_param}")
 
+    encoder_embed = get_encoder_embed(params)
+    encoder = get_encoder_model(params)
+
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
@@ -1622,15 +1638,15 @@ def run(rank, world_size, args):
         valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )
 
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         tokenizer=tokenizer,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            tokenizer=tokenizer,
+            params=params,
+        )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints: