Update documentation to PromptASR (#1321)

This commit is contained in:
marcoyang1998 2023-10-19 17:24:31 +08:00 committed by GitHub
parent 36c60b0cf6
commit ce372cce33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,9 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,)
# Zengwei Yao)
# Mingshuang Luo
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -21,21 +22,35 @@
Usage:
# For mix precision training:
# For mix precision training, using MCP style transcript:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
./zipformer_prompt_asr/train_baseline.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--exp-dir zipformer_prompt_asr/exp \
--transcript-style MCP \
--max-duration 1000
# For mix precision training, using UC style transcript:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_prompt_asr/train_baseline.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_prompt_asr/exp \
--transcript-style UC \
--max-duration 1000
# To train a streaming model
./zipformer/train.py \
./zipformer_prompt_asr/train_baseline.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@ -100,7 +115,7 @@ from icefall.utils import (
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def get_first(
def get_mixed_cased_with_punc(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
@ -479,6 +494,16 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--transcript-style",
type=str,
default="UC",
choices=["UC", "MCP"],
help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations,
MCP stands for mix-cased text with punctuation.
""",
)
add_model_arguments(parser)
return parser
@ -1223,7 +1248,11 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
text_sampling_func = get_upper_only_alpha
if params.transcript_style == "UC":
text_sampling_func = get_upper_only_alpha
else:
text_sampling_func = get_mixed_cased_with_punc
logging.info(f"Using {params.transcript_style} style for training.")
logging.info(f"Text sampling func: {text_sampling_func}")
train_dl = libriheavy.train_dataloaders(
train_cuts,