using monkey patch to replace models

This commit is contained in:
Yuekai Zhang 2024-01-22 14:41:14 +08:00
parent 84e4af93d7
commit bda48291db
5 changed files with 102 additions and 26 deletions

View File

@ -13,12 +13,13 @@
Command for training is:
```bash
pip install -r whisper/requirements.txt
./prepare.sh --stage 30 --stop_stage 30
#fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--use-fp16 1 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
@ -27,21 +28,33 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--use-fp16 1 \
--exp-dir whisper/exp_medium \
--base-lr 1e-5 \
--model-name medium
```
Command for decoding is:
Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```
Pretrained models, training logs, decoding logs, tensorboard and decoding results
Command for decoding using pretrained models (before fine-tuning):
```bash
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch -1 --avg 1 \
--remove-whisper-encoder-input-length-restriction False \
--beam-size 10 --max-duration 50
```
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_aishell_whisper>

View File

@ -16,6 +16,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
# Command for decoding using pretrained models (before fine-tuning):
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch -1 --avg 1 \
--remove-whisper-encoder-input-length-restriction False \
--beam-size 10 --max-duration 50
"""
import argparse
import logging
@ -29,8 +52,8 @@ import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from model import load_model
#from model import load_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
get_lattice,
@ -104,7 +127,7 @@ def average_checkpoints(
def remove_punctuation(text: str or List[str]):
# https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
punctuation = '!,.;:?、!,。;:?'
punctuation = '!,.;:?、!,。;:?《》 '
if isinstance(text, str):
text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
return text
@ -183,6 +206,13 @@ def get_parser():
help="""The model name to use.
""",
)
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
return parser
@ -246,6 +276,10 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
if not params.remove_whisper_encoder_input_length_restriction:
T = 3000
if feature.shape[2] < T:
feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
@ -404,7 +438,9 @@ def main():
logging.info(f"device: {device}")
model = load_model(params.model_name)
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, 'cpu')
if params.epoch > 0:
if params.avg > 1:
start = params.epoch - params.avg

View File

@ -4,7 +4,7 @@ git+https://github.com/lhotse-speech/lhotse
sentencepiece
tensorboard
librosa
openai-whisper==20231117
openai-whisper==git+https://github.com/yuekaizhang/whisper.git
zhconv
WeTextProcessing
deepspeed

View File

@ -17,20 +17,20 @@
"""
Usage:
./prepare.sh
#fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
If you use --datatang-prob=0, then you don't need to run the above script.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 550
# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_medium \
--base-lr 1e-5 \
--model-name medium
"""
@ -88,7 +88,7 @@ from icefall.utils import (
)
import whisper
from model import load_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from label_smoothing import LabelSmoothingLoss
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -227,7 +227,7 @@ def get_parser():
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
default=True,
help="Whether to use half precision training.",
)
@ -744,8 +744,9 @@ def run(rank, world_size, args):
logging.info(params)
logging.info("About to create model")
model = load_model(params.model_name)
replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, 'cpu')
del model.alignment_heads
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

View File

@ -0,0 +1,26 @@
import torch
import whisper
def forward(self, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
def replace_whisper_encoder_forward():
"""
This function monkey patches the forward method of the whisper encoder.
To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
"""
whisper.model.AudioEncoder.forward = forward