mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Update documentation to PromptASR (#1321)
This commit is contained in:
parent
36c60b0cf6
commit
ce372cce33
@ -1,8 +1,9 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang,
|
# Wei Kang,
|
||||||
# Mingshuang Luo,)
|
# Mingshuang Luo
|
||||||
# Zengwei Yao)
|
# Zengwei Yao,
|
||||||
|
# Xiaoyu Yang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -21,21 +22,35 @@
|
|||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
|
|
||||||
# For mix precision training:
|
# For mixed-precision training, using MCP-style transcript:
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./zipformer/train.py \
|
./zipformer_prompt_asr/train_baseline.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir zipformer/exp \
|
--exp-dir zipformer_prompt_asr/exp \
|
||||||
|
--transcript-style MCP \
|
||||||
|
--max-duration 1000
|
||||||
|
|
||||||
|
# For mixed-precision training, using UC-style transcript:
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
|
./zipformer_prompt_asr/train_baseline.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--exp-dir zipformer_prompt_asr/exp \
|
||||||
|
--transcript-style UC \
|
||||||
--max-duration 1000
|
--max-duration 1000
|
||||||
|
|
||||||
# To train a streaming model
|
# To train a streaming model
|
||||||
|
|
||||||
./zipformer/train.py \
|
./zipformer_prompt_asr/train_baseline.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
@ -100,7 +115,7 @@ from icefall.utils import (
|
|||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
def get_first(
|
def get_mixed_cased_with_punc(
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
pre_texts: List[str],
|
pre_texts: List[str],
|
||||||
context_list: Optional[str] = None,
|
context_list: Optional[str] = None,
|
||||||
@ -479,6 +494,16 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--transcript-style",
|
||||||
|
type=str,
|
||||||
|
default="UC",
|
||||||
|
choices=["UC", "MCP"],
|
||||||
|
help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations,
|
||||||
|
MCP stands for mix-cased text with punctuation.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -1223,7 +1248,11 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
|
|
||||||
text_sampling_func = get_upper_only_alpha
|
if params.transcript_style == "UC":
|
||||||
|
text_sampling_func = get_upper_only_alpha
|
||||||
|
else:
|
||||||
|
text_sampling_func = get_mixed_cased_with_punc
|
||||||
|
logging.info(f"Using {params.transcript_style} style for training.")
|
||||||
logging.info(f"Text sampling func: {text_sampling_func}")
|
logging.info(f"Text sampling func: {text_sampling_func}")
|
||||||
train_dl = libriheavy.train_dataloaders(
|
train_dl = libriheavy.train_dataloaders(
|
||||||
train_cuts,
|
train_cuts,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user