mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
using monkey patch to replace models
This commit is contained in:
parent
84e4af93d7
commit
bda48291db
@ -13,12 +13,13 @@
|
||||
|
||||
Command for training is:
|
||||
```bash
|
||||
pip install -r whisper/requirements.txt
|
||||
|
||||
./prepare.sh --stage 30 --stop_stage 30
|
||||
|
||||
#fine-tuning with deepspeed zero stage 1
|
||||
torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--deepspeed \
|
||||
@ -27,21 +28,33 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
# fine-tuning with ddp
|
||||
torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir whisper/exp_medium \
|
||||
--base-lr 1e-5 \
|
||||
--model-name medium
|
||||
```
|
||||
|
||||
Command for decoding is:
|
||||
Command for decoding using fine-tuned models:
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
|
||||
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch 999 --avg 1 \
|
||||
--beam-size 10 --max-duration 50
|
||||
```
|
||||
Pretrained models, training logs, decoding logs, tensorboard and decoding results
|
||||
Command for decoding using pretrained models (before fine-tuning):
|
||||
```bash
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch -1 --avg 1 \
|
||||
--remove-whisper-encoder-input-length-restriction False \
|
||||
--beam-size 10 --max-duration 50
|
||||
```
|
||||
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/yuekai/icefall_asr_aishell_whisper>
|
||||
|
||||
|
@ -16,6 +16,29 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
# Command for decoding using fine-tuned models:
|
||||
git lfs install
|
||||
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
|
||||
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch 999 --avg 1 \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
# Command for decoding using pretrained models (before fine-tuning):
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch -1 --avg 1 \
|
||||
--remove-whisper-encoder-input-length-restriction False \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@ -29,8 +52,8 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from model import load_model
|
||||
|
||||
#from model import load_model
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
@ -104,7 +127,7 @@ def average_checkpoints(
|
||||
|
||||
def remove_punctuation(text: str or List[str]):
|
||||
# https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
|
||||
punctuation = '!,.;:?、!,。;:?'
|
||||
punctuation = '!,.;:?、!,。;:?《》 '
|
||||
if isinstance(text, str):
|
||||
text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
|
||||
return text
|
||||
@ -183,6 +206,13 @@ def get_parser():
|
||||
help="""The model name to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--remove-whisper-encoder-input-length-restriction",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="replace whisper encoder forward method to remove input length restriction",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -246,6 +276,10 @@ def decode_one_batch(
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device, dtype=dtype).transpose(1, 2)
|
||||
if not params.remove_whisper_encoder_input_length_restriction:
|
||||
T = 3000
|
||||
if feature.shape[2] < T:
|
||||
feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_len = supervisions["num_frames"]
|
||||
@ -404,7 +438,9 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = load_model(params.model_name)
|
||||
if params.remove_whisper_encoder_input_length_restriction:
|
||||
replace_whisper_encoder_forward()
|
||||
model = whisper.load_model(params.model_name, 'cpu')
|
||||
if params.epoch > 0:
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg
|
||||
|
@ -4,7 +4,7 @@ git+https://github.com/lhotse-speech/lhotse
|
||||
sentencepiece
|
||||
tensorboard
|
||||
librosa
|
||||
openai-whisper==20231117
|
||||
openai-whisper==git+https://github.com/yuekaizhang/whisper.git
|
||||
zhconv
|
||||
WeTextProcessing
|
||||
deepspeed
|
@ -17,20 +17,20 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
./prepare.sh
|
||||
#fine-tuning with deepspeed zero stage 1
|
||||
torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper/ds_config_zero1.json
|
||||
|
||||
If you use --datatang-prob=0, then you don't need to run the above script.
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
# fine-tuning with ddp
|
||||
torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir whisper/exp_medium \
|
||||
--base-lr 1e-5 \
|
||||
--model-name medium
|
||||
"""
|
||||
|
||||
|
||||
@ -88,7 +88,7 @@ from icefall.utils import (
|
||||
)
|
||||
|
||||
import whisper
|
||||
from model import load_model
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -227,7 +227,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
@ -744,8 +744,9 @@ def run(rank, world_size, args):
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
model = load_model(params.model_name)
|
||||
|
||||
replace_whisper_encoder_forward()
|
||||
model = whisper.load_model(params.model_name, 'cpu')
|
||||
del model.alignment_heads
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import whisper
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
|
||||
def replace_whisper_encoder_forward():
|
||||
"""
|
||||
This function monkey patches the forward method of the whisper encoder.
|
||||
To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
|
||||
"""
|
||||
whisper.model.AudioEncoder.forward = forward
|
Loading…
x
Reference in New Issue
Block a user