From bda48291dbc13884de913fe0780d6d620cc13662 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Mon, 22 Jan 2024 14:41:14 +0800 Subject: [PATCH] using monkey patch to replace models --- egs/aishell/ASR/RESULTS.md | 21 +++++++-- egs/aishell/ASR/whisper/decode.py | 44 +++++++++++++++++-- egs/aishell/ASR/whisper/requirements.txt | 2 +- egs/aishell/ASR/whisper/train.py | 35 ++++++++------- .../whisper_encoder_forward_monkey_patch.py | 26 +++++++++++ 5 files changed, 102 insertions(+), 26 deletions(-) create mode 100644 egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 626a5346f..00241dca7 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -13,12 +13,13 @@ Command for training is: ```bash +pip install -r whisper/requirements.txt + ./prepare.sh --stage 30 --stop_stage 30 #fine-tuning with deepspeed zero stage 1 torchrun --nproc-per-node 8 ./whisper/train.py \ --max-duration 200 \ - --use-fp16 1 \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ --deepspeed \ @@ -27,21 +28,33 @@ torchrun --nproc-per-node 8 ./whisper/train.py \ # fine-tuning with ddp torchrun --nproc-per-node 8 ./whisper/train.py \ --max-duration 200 \ - --use-fp16 1 \ --exp-dir whisper/exp_medium \ --base-lr 1e-5 \ --model-name medium ``` -Command for decoding is: +Command for decoding using fine-tuned models: ```bash +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + python3 ./whisper/decode.py \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ --epoch 999 --avg 1 \ --beam-size 10 --max-duration 50 ``` -Pretrained models, training logs, decoding logs, tensorboard and decoding results +Command for decoding using pretrained models (before fine-tuning): +```bash +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 +``` +Fine-tuned models, training logs, decoding logs, tensorboard and decoding results are available at diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py index 6c09b142b..b87d15524 100755 --- a/egs/aishell/ASR/whisper/decode.py +++ b/egs/aishell/ASR/whisper/decode.py @@ -16,6 +16,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Usage: +# Command for decoding using fine-tuned models: +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 + +# Command for decoding using pretrained models (before fine-tuning): + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 + +""" import argparse import logging @@ -29,8 +52,8 @@ import k2 import torch import torch.nn as nn from asr_datamodule import AishellAsrDataModule -from model import load_model - +#from model import load_model +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model from icefall.decode import ( get_lattice, @@ -104,7 +127,7 @@ def average_checkpoints( def remove_punctuation(text: str or List[str]): # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py - punctuation = '!,.;:?、!,。;:?' + punctuation = '!,.;:?、!,。;:?《》 ' if isinstance(text, str): text = re.sub(r'[{}]+'.format(punctuation), '', text).strip() return text @@ -183,6 +206,13 @@ def get_parser(): help="""The model name to use. """, ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) return parser @@ -246,6 +276,10 @@ def decode_one_batch( feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2) supervisions = batch["supervisions"] feature_len = supervisions["num_frames"] @@ -404,7 +438,9 @@ def main(): logging.info(f"device: {device}") - model = load_model(params.model_name) + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, 'cpu') if params.epoch > 0: if params.avg > 1: start = params.epoch - params.avg diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt index 623765f9f..873ea10be 100755 --- a/egs/aishell/ASR/whisper/requirements.txt +++ b/egs/aishell/ASR/whisper/requirements.txt @@ -4,7 +4,7 @@ git+https://github.com/lhotse-speech/lhotse sentencepiece tensorboard librosa -openai-whisper==20231117 +openai-whisper==git+https://github.com/yuekaizhang/whisper.git zhconv WeTextProcessing deepspeed \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index 8b133b2a4..4251536ad 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -17,20 +17,20 @@ """ Usage: -./prepare.sh +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json -If you use --datatang-prob=0, then you don't need to run the above script. - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 +# fine-tuning with ddp +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --base-lr 1e-5 \ + --model-name medium """ @@ -88,7 +88,7 @@ from icefall.utils import ( ) import whisper -from model import load_model +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from label_smoothing import LabelSmoothingLoss LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -227,7 +227,7 @@ def get_parser(): parser.add_argument( "--use-fp16", type=str2bool, - default=False, + default=True, help="Whether to use half precision training.", ) @@ -744,8 +744,9 @@ def run(rank, world_size, args): logging.info(params) logging.info("About to create model") - - model = load_model(params.model_name) + + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, 'cpu') del model.alignment_heads num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py new file mode 100644 index 000000000..0f2b94adf --- /dev/null +++ b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1,26 @@ +import torch +import whisper + +def forward(self, x: torch.Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + +def replace_whisper_encoder_forward(): + """ + This function monkey patches the forward method of the whisper encoder. + To be called before the model is loaded, it changes whisper to process audio with any length < 30s. + """ + whisper.model.AudioEncoder.forward = forward