Mirror of https://github.com/k2-fsa/icefall.git

Commit 7a9c18fc79 ("update training script")
Parent: 9f48d06581
@@ -22,7 +22,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 # For mix precision training:

-(1) Non-streaming model, without context list
+(1) Non-streaming model, **without** context list

 ./zipformer_prompt_asr/train.py \
   --world-size 4 \
@@ -34,12 +34,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir zipformer_prompt_asr/exp \
   --max-duration 1000 \
   --memory-layer 0 \
-  --memory-dim 768 \
   --text-encoder-type BERT \
+  --text-encoder-dim 768 \
   --use-style-prompt True \
   --use-context-list False

-(2) Non-streaming model, with context list
+(2) Non-streaming model, **with** context list

 ./zipformer_prompt_asr/train.py \
   --world-size 4 \
@@ -51,10 +51,11 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir zipformer_prompt_asr/exp \
   --max-duration 1000 \
   --memory-layer 0 \
-  --memory-dim 768 \
   --text-encoder-type BERT \
+  --text-encoder-dim 768 \
   --use-style-prompt True \
   --use-context-list True \
+  --top-k 10000 \
   --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt

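Note that the updated usage passes --text-encoder-dim instead of the old --memory-dim. The value must match the hidden size of the selected --text-encoder-type; the asserts added later in this diff expect 768 for BERT (bert-base-cased) and DistilBERT, and 1024 for BERT-large (bert-large-uncased). A minimal sketch of that pairing (the helper below is illustrative, not part of train.py):

    # Illustrative only: mirrors the asserts this commit adds to get_text_encoder().
    EXPECTED_TEXT_ENCODER_DIM = {
        "BERT": 768,         # bert-base-cased
        "BERT-large": 1024,  # bert-large-uncased
        "DistilBERT": 768,   # distilbert-base-cased
    }

    def check_text_encoder_dim(text_encoder_type: str, text_encoder_dim: int) -> None:
        expected = EXPECTED_TEXT_ENCODER_DIM[text_encoder_type]
        if text_encoder_dim != expected:
            raise ValueError(
                f"--text-encoder-dim {text_encoder_dim} does not match "
                f"--text-encoder-type {text_encoder_type} (expected {expected})"
            )

    check_text_encoder_dim("BERT", 768)         # passes
    check_text_encoder_dim("BERT-large", 1024)  # passes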
@@ -64,6 +65,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import copy
 import logging
+import os
 import random
 import warnings
 from pathlib import Path
@@ -134,22 +136,6 @@ style_transforms = [
 ]


-def random_sampling(texts: List[str]) -> str:
-    return random.choice(texts)
-
-
-def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
-    # Randomly choose from the ground truth (mixed-cased trans) and the recog_text
-    i = random.randint(0, 1)
-    out = {
-        "text": texts[i],
-        "pre_text": pre_texts[i],
-        "style_text": "",
-        "transform_ids": 0,
-    }
-    return out
-
-
 def get_first(texts: List[str], pre_texts: List[str]) -> str:
     out = {
         "text": texts[0],
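For context: the removed helpers and the kept get_first all implement the same text_sampling_func contract used by the data loader (naive_triplet_text_sampling appears further down in this diff). Such a function takes the candidate transcripts and previous-context texts and returns a dict with "text", "pre_text", "style_text" and "transform_ids". A sketch of a custom sampler following that contract (the function name is ours, not part of the recipe):

    import random
    from typing import Dict, List, Union

    def sample_with_random_pre_text(
        texts: List[str], pre_texts: List[str]
    ) -> Dict[str, Union[str, int]]:
        # Keep the first (ground-truth, mixed-cased) transcript as the target,
        # but pick the preceding-context text at random; same dict layout as
        # get_first() and the removed joint_random_sampling().
        return {
            "text": texts[0],
            "pre_text": random.choice(pre_texts),
            "style_text": "",
            "transform_ids": 0,
        }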
@@ -313,7 +299,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--causal",
         type=str2bool,
-        default=True,
+        default=False,
         help="If True, use causal version of model.",
     )

@@ -348,6 +334,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Type of the text encoder",
     )

+    parser.add_argument(
+        "--text-encoder-dim",
+        type=int,
+        default=768,
+        help="Dimension of the text encoder",
+    )
+
     parser.add_argument(
         "--text-encoder-adapter",
         type=str2bool,
@@ -762,18 +755,30 @@ class TextEmbedding(nn.Module):

 def get_text_encoder(params: AttributeDict) -> nn.Module:
     # Return a text encoder
-    if params.text_encoder_type == "BERT":
+    if params.text_encoder_type == "BERT":  # This is a BERT-base-cased
         from transformers import BertModel

-        # This is a BERT-base-cased
         logging.info("Loading pre-trained BERT-base-cased as text encoder")
-        model = BertModel.from_pretrained("bert-base-cased")
-    elif params.text_encoder_type == "DistilBERT":
-        from transformers import DistilBertModel
+        if os.path.exists("data/models/bert-base-cased"):
+            model = BertModel.from_pretrained("data/models/bert-base-cased")
+        else:
+            model = BertModel.from_pretrained("bert-base-cased")
+        assert params.text_encoder_dim == 768
+    elif params.text_encoder_type == "BERT-large":
+        from transformers import BertModel
+
+        logging.info("Loading pre-trained BERT-large-uncased as text encoder")
+        if os.path.exists("data/models/bert-large-uncased"):
+            model = BertModel.from_pretrained("data/models/bert-large-uncased")
+        else:
+            model = BertModel.from_pretrained("bert-large-uncased")
+        assert params.text_encoder_dim == 1024
+    elif params.text_encoder_type == "DistilBERT":
+        from transformers import DistilBertModel  # This is a DistilBERT-base-cased

-        # This is a DistilBERT-base-cased
         logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
+        assert params.text_encoder_dim == 768
     else:
         raise ValueError()

@@ -786,7 +791,18 @@ def get_tokenizer(params: AttributeDict):
         from transformers import BertTokenizer

         # This is a BERT-base-cased
-        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+        if os.path.exists("data/models/bert-base-cased"):
+            tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased")
+        else:
+            tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+    elif params.text_encoder_type == "BERT-large":
+        from transformers import BertTokenizer
+
+        # This is a BERT-large-uncased
+        if os.path.exists("data/models/bert-large-uncased"):
+            tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased")
+        else:
+            tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertTokenizer

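Both get_text_encoder and get_tokenizer now check data/models/ before falling back to the Hugging Face hub, so training can run on machines without internet access once the models are cached. One way to populate that cache ahead of time (a one-off script of ours, not part of this commit):

    # Run once on a machine with internet access; afterwards the new
    # os.path.exists("data/models/...") branches load everything locally.
    from transformers import BertModel, BertTokenizer

    for name in ["bert-base-cased", "bert-large-uncased"]:
        BertModel.from_pretrained(name).save_pretrained(f"data/models/{name}")
        BertTokenizer.from_pretrained(name).save_pretrained(f"data/models/{name}")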
@@ -816,7 +832,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
-        memory_dim=768,  # This is fixed as the BERT base model is 768-D
+        memory_dim=params.text_encoder_dim,  # This is fixed as the BERT base model is 768-D
         memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
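With memory_dim now taken from params.text_encoder_dim, the zipformer's cross-attention memory size follows whichever text encoder is selected; the trailing "fixed ... 768-D" comment only holds for the BERT-base and DistilBERT cases, since BERT-large is 1024-D. A quick check of those hidden sizes (assuming the transformers AutoConfig API, outside of train.py):

    from transformers import AutoConfig

    assert AutoConfig.from_pretrained("bert-base-cased").hidden_size == 768
    assert AutoConfig.from_pretrained("bert-large-uncased").hidden_size == 1024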
@@ -848,13 +864,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:


 def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder_embed = get_encoder_embed(params)
-    encoder = get_encoder_model(params)
-
     text_encoder = get_text_encoder(params)  # This should be a cased BERT base model
     num_param = sum([p.numel() for p in text_encoder.parameters()])
     logging.info(f"Num params in text encoder: {num_param}")

+    encoder_embed = get_encoder_embed(params)
+    encoder = get_encoder_model(params)
+
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

@@ -1622,15 +1638,15 @@ def run(rank, world_size, args):
         valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )

-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         tokenizer=tokenizer,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            tokenizer=tokenizer,
+            params=params,
+        )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
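Re-enabling scan_pessimistic_batches_for_oom means the most memory-hungry batches are pushed through the model once before training starts, so an out-of-memory error surfaces immediately rather than hours into an epoch. Roughly, the idea is the following (a sketch of the concept, not icefall's actual implementation):

    def scan_for_oom_sketch(model, heavy_batches, compute_loss) -> None:
        # Try the largest batches once: allocate activations and gradients as
        # real training would, but discard them without an optimizer step.
        for batch in heavy_batches:
            loss = compute_loss(model, batch)
            loss.backward()
            model.zero_grad(set_to_none=True)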