Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-05 23:24:17 +00:00)

Commit bda48291db (parent 84e4af93d7): using monkey patch to replace models

This commit drops the recipe's local `load_model` helper in favour of upstream `whisper.load_model` plus a monkey-patched encoder `forward` (new file `whisper/whisper_encoder_forward_monkey_patch.py`), so the encoder can accept inputs shorter than Whisper's fixed 30-second context during fine-tuning and decoding.
Recipe documentation (markdown):

@@ -13,12 +13,13 @@

 Command for training is:
 ```bash
+pip install -r whisper/requirements.txt
+
 ./prepare.sh --stage 30 --stop_stage 30

 #fine-tuning with deepspeed zero stage 1
 torchrun --nproc-per-node 8 ./whisper/train.py \
   --max-duration 200 \
-  --use-fp16 1 \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
   --deepspeed \
@@ -27,21 +28,33 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
 # fine-tuning with ddp
 torchrun --nproc-per-node 8 ./whisper/train.py \
   --max-duration 200 \
-  --use-fp16 1 \
   --exp-dir whisper/exp_medium \
   --base-lr 1e-5 \
   --model-name medium
 ```

-Command for decoding is:
+Command for decoding using fine-tuned models:
 ```bash
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
 python3 ./whisper/decode.py \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
   --epoch 999 --avg 1 \
   --beam-size 10 --max-duration 50
 ```
-Pretrained models, training logs, decoding logs, tensorboard and decoding results
+Command for decoding using pretrained models (before fine-tuning):
+```bash
+python3 ./whisper/decode.py \
+  --exp-dir whisper/exp_large_v2 \
+  --model-name large-v2 \
+  --epoch -1 --avg 1 \
+  --remove-whisper-encoder-input-length-restriction False \
+  --beam-size 10 --max-duration 50
+```
+Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
 are available at
 <https://huggingface.co/yuekai/icefall_asr_aishell_whisper>

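The `ln -s ... epoch-999.pt` step above relies on icefall's `--epoch`/`--avg` checkpoint naming: with `--avg 1`, decode.py is expected to load `<exp-dir>/epoch-<epoch>.pt` directly, so symlinking the released `epoch-10-avg6.pt` as `epoch-999.pt` lets `--epoch 999 --avg 1` pick up the fine-tuned weights without averaging, while `--epoch -1` skips checkpoint loading altogether and decodes with the original OpenAI weights (see the `if params.epoch > 0:` guard in the decode.py hunk further down). A minimal sketch of that naming convention; `resolve_checkpoint` is an illustrative helper, not a function from the recipe:

```python
# Sketch of the checkpoint naming assumed by "--epoch N --avg 1"; the real loading
# logic lives in whisper/decode.py and icefall.checkpoint.
from pathlib import Path
from typing import Optional


def resolve_checkpoint(exp_dir: str, epoch: int, avg: int) -> Optional[Path]:
    """Return the single checkpoint implied by --epoch/--avg, or None for --epoch -1."""
    if epoch <= 0:
        # --epoch -1: no fine-tuned checkpoint, decode with the original Whisper weights.
        return None
    if avg == 1:
        # e.g. whisper/exp_large_v2/epoch-999.pt, which the README symlinks
        # to the released epoch-10-avg6.pt.
        return Path(exp_dir) / f"epoch-{epoch}.pt"
    raise NotImplementedError("avg > 1 triggers checkpoint averaging in decode.py")


print(resolve_checkpoint("whisper/exp_large_v2", 999, 1))  # whisper/exp_large_v2/epoch-999.pt
print(resolve_checkpoint("whisper/exp_large_v2", -1, 1))   # None
```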
whisper/decode.py:

@@ -16,6 +16,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Usage:
+# Command for decoding using fine-tuned models:
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+  --exp-dir whisper/exp_large_v2 \
+  --model-name large-v2 \
+  --epoch 999 --avg 1 \
+  --beam-size 10 --max-duration 50
+
+# Command for decoding using pretrained models (before fine-tuning):
+
+python3 ./whisper/decode.py \
+  --exp-dir whisper/exp_large_v2 \
+  --model-name large-v2 \
+  --epoch -1 --avg 1 \
+  --remove-whisper-encoder-input-length-restriction False \
+  --beam-size 10 --max-duration 50
+
+"""

 import argparse
 import logging
@@ -29,8 +52,8 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
-from model import load_model
+#from model import load_model
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
 from icefall.decode import (
     get_lattice,
@@ -104,7 +127,7 @@ def average_checkpoints(

 def remove_punctuation(text: str or List[str]):
     # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-    punctuation = '!,.;:?、!,。;:?'
+    punctuation = '!,.;:?、!,。;:?《》 '
     if isinstance(text, str):
         text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
         return text
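For reference, the updated `remove_punctuation` can be exercised on its own; this standalone version mirrors the hunk above (the list-input branch of the original sits outside this hunk, so its handling here is an assumption) and shows the effect of adding `《》` and a space to the character class:

```python
import re
from typing import List, Union

# Same character class as the updated decode.py: ASCII and full-width punctuation,
# plus the newly added 《》 quotes and a space.
PUNCTUATION = '!,.;:?、!,。;:?《》 '


def remove_punctuation(text: Union[str, List[str]]):
    """Strip punctuation before scoring, mirroring decode.py."""
    if isinstance(text, str):
        return re.sub(r"[{}]+".format(PUNCTUATION), "", text).strip()
    # Assumed handling for list input (not shown in the hunk above).
    return [re.sub(r"[{}]+".format(PUNCTUATION), "", t).strip() for t in text]


print(remove_punctuation("《测试》, 你好!"))  # -> 测试你好
```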
whisper/decode.py (continued):

@@ -183,6 +206,13 @@ def get_parser():
         help="""The model name to use.
         """,
     )

+    parser.add_argument(
+        "--remove-whisper-encoder-input-length-restriction",
+        type=str2bool,
+        default=True,
+        help="replace whisper encoder forward method to remove input length restriction",
+    )
+
     return parser

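The new flag is declared with icefall's `str2bool`, so `--remove-whisper-encoder-input-length-restriction False` becomes `params.remove_whisper_encoder_input_length_restriction == False` inside the script (argparse turns the dashes into underscores). A self-contained sketch of the same pattern, with a stand-in `str2bool` since the real one comes from `icefall.utils`:

```python
import argparse


def str2bool(v):
    """Stand-in for icefall.utils.str2bool: map 'true'/'false'-style strings to bool."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--remove-whisper-encoder-input-length-restriction",
    type=str2bool,
    default=True,
    help="replace whisper encoder forward method to remove input length restriction",
)

args = parser.parse_args(["--remove-whisper-encoder-input-length-restriction", "False"])
print(args.remove_whisper_encoder_input_length_restriction)  # False
```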
whisper/decode.py (continued):

@@ -246,6 +276,10 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device, dtype=dtype).transpose(1, 2)
+    if not params.remove_whisper_encoder_input_length_restriction:
+        T = 3000
+        if feature.shape[2] < T:
+            feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)

     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
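When the restriction is kept (flag set to False, i.e. decoding with the unmodified pretrained encoder), the mel features must be zero-padded along the frame axis to Whisper's fixed 30-second context of 3000 frames, which is what the added `torch.cat` does above. An equivalent standalone sketch using `torch.nn.functional.pad` for readability:

```python
import torch
import torch.nn.functional as F


def pad_mel_to_whisper_context(feature: torch.Tensor, target_frames: int = 3000) -> torch.Tensor:
    """Zero-pad (batch, n_mels, frames) mel features to 3000 frames, matching
    Whisper's fixed 30 s encoder input when the monkey patch is not applied."""
    assert feature.ndim == 3
    pad = target_frames - feature.shape[2]
    if pad > 0:
        # F.pad pads the last dimension by (left, right) amounts.
        feature = F.pad(feature, (0, pad))
    return feature


mel = torch.randn(4, 80, 1234)  # e.g. roughly 12 s of audio, 80 mel bins, 100 frames/s
print(pad_mel_to_whisper_context(mel).shape)  # torch.Size([4, 80, 3000])
```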
whisper/decode.py (continued):

@@ -404,7 +438,9 @@ def main():

     logging.info(f"device: {device}")

-    model = load_model(params.model_name)
+    if params.remove_whisper_encoder_input_length_restriction:
+        replace_whisper_encoder_forward()
+    model = whisper.load_model(params.model_name, 'cpu')
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
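In the hunk above, the recipe's old `load_model` is replaced by upstream `whisper.load_model(..., 'cpu')`, and `replace_whisper_encoder_forward()` runs first when the length restriction is being removed; the new module's docstring asks for the patch to be applied before the model is loaded, and train.py follows the same order (see its hunk below). A compact sketch of that call sequence, with checkpoint loading and device placement omitted and the helper name `build_model` purely illustrative:

```python
import whisper

from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward


def build_model(model_name: str, remove_input_length_restriction: bool = True):
    """Load a Whisper model, optionally monkey patching the encoder first."""
    if remove_input_length_restriction:
        # Patch AudioEncoder.forward before the model is instantiated,
        # as the monkey-patch module's docstring recommends.
        replace_whisper_encoder_forward()
    # Load on CPU first; decode.py/train.py move the model to the target device later.
    return whisper.load_model(model_name, "cpu")


model = build_model("large-v2")
```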
whisper/requirements.txt:

@@ -4,7 +4,7 @@ git+https://github.com/lhotse-speech/lhotse
 sentencepiece
 tensorboard
 librosa
-openai-whisper==20231117
+openai-whisper==git+https://github.com/yuekaizhang/whisper.git
 zhconv
 WeTextProcessing
 deepspeed
whisper/train.py:

@@ -17,20 +17,20 @@
 """
 Usage:

-./prepare.sh
-
-If you use --datatang-prob=0, then you don't need to run the above script.
-
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
-
-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+#fine-tuning with deepspeed zero stage 1
+torchrun --nproc-per-node 8 ./whisper/train.py \
+  --max-duration 200 \
+  --exp-dir whisper/exp_large_v2 \
+  --model-name large-v2 \
+  --deepspeed \
+  --deepspeed_config ./whisper/ds_config_zero1.json
+
+# fine-tuning with ddp
+torchrun --nproc-per-node 8 ./whisper/train.py \
+  --max-duration 200 \
+  --exp-dir whisper/exp_medium \
+  --base-lr 1e-5 \
+  --model-name medium
 """


@@ -88,7 +88,7 @@ from icefall.utils import (
 )

 import whisper
-from model import load_model
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from label_smoothing import LabelSmoothingLoss

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -227,7 +227,7 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )

@@ -744,8 +744,9 @@ def run(rank, world_size, args):
     logging.info(params)

     logging.info("About to create model")

-    model = load_model(params.model_name)
+    replace_whisper_encoder_forward()
+    model = whisper.load_model(params.model_name, 'cpu')
     del model.alignment_heads
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
whisper/whisper_encoder_forward_monkey_patch.py (new file):

@@ -0,0 +1,26 @@
+import torch
+import whisper
+
+def forward(self, x: torch.Tensor):
+    """
+    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+        the mel spectrogram of the audio
+    """
+    x = F.gelu(self.conv1(x))
+    x = F.gelu(self.conv2(x))
+    x = x.permute(0, 2, 1)
+
+    x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
+
+    for block in self.blocks:
+        x = block(x)
+
+    x = self.ln_post(x)
+    return x
+
+def replace_whisper_encoder_forward():
+    """
+    This function monkey patches the forward method of the whisper encoder.
+    To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
+    """
+    whisper.model.AudioEncoder.forward = forward
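As listed, the new module calls `F.gelu` without importing `F`; presumably `torch.nn.functional` is intended, since that is how upstream Whisper's `AudioEncoder.forward` uses it. Below is a self-contained version of the patched forward with that import made explicit and with comments on the one behavioural change, slicing the positional embedding to the actual input length instead of requiring a full 30-second input; apart from the added import and comments this mirrors the file above:

```python
import torch
import torch.nn.functional as F  # assumed: the committed file uses F.gelu without importing F
import whisper


def forward(self, x: torch.Tensor):
    """
    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
        the mel spectrogram of the audio
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    # Slice the positional embedding to the actual number of frames instead of
    # asserting a fixed-length input, so audio shorter than 30 s passes through.
    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x


def replace_whisper_encoder_forward():
    """Monkey patch whisper's AudioEncoder.forward so it accepts audio shorter than 30 s."""
    whisper.model.AudioEncoder.forward = forward
```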