diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 7075c9154..32302602c 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,21 +22,35 @@ Usage: -# For mix precision training: +# For mix precision training, using MCP style transcript: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./zipformer/train.py \ +./zipformer_prompt_asr/train_baseline.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir zipformer/exp \ + --exp-dir zipformer_prompt_asr/exp \ + --transcript-style MCP \ + --max-duration 1000 + +# For mix precision training, using UC style transcript: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer_prompt_asr/train_baseline.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer_prompt_asr/exp \ + --transcript-style UC \ --max-duration 1000 # To train a streaming model -./zipformer/train.py \ +./zipformer_prompt_asr/train_baseline.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ @@ -100,7 +115,7 @@ from icefall.utils import ( LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -def get_first( +def get_mixed_cased_with_punc( texts: List[str], pre_texts: List[str], context_list: Optional[str] = None, @@ -479,6 +494,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--transcript-style", + type=str, + default="UC", + choices=["UC", "MCP"], + help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations, + MCP stands for mix-cased text with punctuation. + """, + ) + add_model_arguments(parser) return parser @@ -1223,7 +1248,11 @@ def run(rank, world_size, args): else: sampler_state_dict = None - text_sampling_func = get_upper_only_alpha + if params.transcript_style == "UC": + text_sampling_func = get_upper_only_alpha + else: + text_sampling_func = get_mixed_cased_with_punc + logging.info(f"Using {params.transcript_style} style for training.") logging.info(f"Text sampling func: {text_sampling_func}") train_dl = libriheavy.train_dataloaders( train_cuts,