support zipformer encoder

(squash of incremental fixup messages: update, fix, reformat, support infer)

commit 489c42b45e, parent 211c01bc1d
@@ -194,10 +194,10 @@ class SPEECH_LLM(nn.Module):
     def forward(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        labels: torch.LongTensor = None,
+        fbank: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
+        labels: torch.LongTensor,
     ):
         encoder_outs = self.encoder(fbank)


egs/speech_llm/ASR_LLM/zipformer_llm_zh/decode.py
@@ -3,6 +3,7 @@
 # Fangjun Kuang,
 # Wei Kang)
 # 2024 Yuekai Zhang
+# 2025 Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -19,31 +20,17 @@
 # limitations under the License.
 """
 Usage:
-# Command for decoding using fine-tuned models:
-
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen models/checkpoint
-huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-
-mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
-ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
-
-python3 ./whisper_llm_zh/decode.py \
+python3 ./zipformer_llm_zh/decode.py \
   --max-duration 80 \
-  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --exp-dir zipformer_llm_zh/exp \
+  --speech-encoder-path-or-name models/zipformer/epoch-999.pt \
   --llm-path-or-name models/qwen \
-  --epoch 999 --avg 1 \
+  --epoch 999 \
+  --avg 1 \
   --manifest-dir data/fbank \
   --use-flash-attn True \
-  --use-lora True --dataset aishell
+  --use-lora True \
+  --dataset aishell
 """

 import argparse
@@ -56,15 +43,22 @@ import k2
 import torch
 import torch.nn as nn
 import transformers
-import whisper
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-from train import DEFAULT_SPEECH_TOKEN
+from train import (
+    DEFAULT_SPEECH_TOKEN,
+    _to_int_tuple,
+    add_model_arguments,
+    get_encoder_embed,
+    get_encoder_model,
+    get_params,
+    load_model_params,
+)
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zipformer import Zipformer2

 from icefall.checkpoint import load_checkpoint
 from icefall.env import get_env_info
@@ -129,43 +123,6 @@ def average_checkpoints(
     return avg


-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
-        help="Path or name of the large language model.",
-    )
-
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -207,17 +164,10 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="whisper/exp",
+        default="zipformer/exp",
         help="The experiment dir",
     )

-    parser.add_argument(
-        "--remove-whisper-encoder-input-length-restriction",
-        type=str2bool,
-        default=True,
-        help="replace whisper encoder forward method to remove input length restriction",
-    )
-
     parser.add_argument(
         "--dataset",
         type=str,
@@ -230,15 +180,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -299,28 +240,13 @@ def decode_one_batch(

         return input_ids, attention_mask

-    dtype = torch.float32
     device = model.llm.device

     feature = batch["inputs"]
     assert feature.ndim == 3
-    feature = feature.to(device, dtype=dtype).transpose(1, 2)
-    if not params.remove_whisper_encoder_input_length_restriction:
-        T = 3000
-        if feature.shape[2] < T:
-            feature = torch.cat(
-                [
-                    feature,
-                    torch.zeros(
-                        feature.shape[0], feature.shape[1], T - feature.shape[2]
-                    ).to(device, dtype=dtype),
-                ],
-                2,
-            )

     supervisions = batch["supervisions"]
-    feature_len = supervisions["num_frames"]
-    feature_len = feature_len.to(device, dtype=dtype)
+    feature_lens = supervisions["num_frames"]

     messages = [
         [
@@ -332,7 +258,10 @@ def decode_one_batch(
     input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

     generated_ids = model.decode(
-        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        feature.to(device),
+        feature_lens.to(device),
+        input_ids.to(device, dtype=torch.long),
+        attention_mask.to(device),
     )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

@@ -471,12 +400,15 @@ def main():

     logging.info(f"device: {device}")

-    if params.remove_whisper_encoder_input_length_restriction:
-        replace_whisper_encoder_forward()
-
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
+    speech_encoder_embed = get_encoder_embed(params)
+    speech_encoder = get_encoder_model(params)
+    load_model_params(
+        params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
+    )
+    load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
+
+    speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
+
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

     if params.use_flash_attn:
@@ -528,6 +460,7 @@ def main():
     )

     model = SPEECH_LLM(
+        speech_encoder_embed,
         speech_encoder,
         llm,
         encoder_projector,

egs/speech_llm/ASR_LLM/zipformer_llm_zh/encoder_interface.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py

egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py
@@ -1,7 +1,12 @@
+from typing import Tuple
+
 import torch
+from encoder_interface import EncoderInterface
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother

+from icefall.utils import make_pad_mask
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@@ -55,11 +60,13 @@ class SPEECH_LLM(nn.Module):

     def __init__(
         self,
-        encoder: nn.Module,
+        encoder_embed: nn.Module,
+        encoder: EncoderInterface,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
         super().__init__()
+        self.encoder_embed = encoder_embed
         self.encoder = encoder
         self.llm = llm
         self.encoder_projector = encoder_projector
@@ -192,14 +199,46 @@ class SPEECH_LLM(nn.Module):

         return final_embedding, final_attention_mask, final_labels, position_ids

+    def forward_encoder(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute encoder outputs.
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+
+        Returns:
+          encoder_out:
+            Encoder output, of shape (N, T, C).
+          encoder_out_lens:
+            Encoder output lengths, of shape (N,).
+        """
+        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+        x, x_lens = self.encoder_embed(x, x_lens)
+        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
+
+        return encoder_out, encoder_out_lens
+
     def forward(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        labels: torch.LongTensor = None,
+        fbank: torch.Tensor,
+        fbank_lens: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
+        labels: torch.LongTensor,
     ):
-        encoder_outs = self.encoder(fbank)
+        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)

         speech_features = self.encoder_projector(encoder_outs)

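The new forward_encoder follows the usual icefall zipformer calling convention: (N, T, C) features plus per-utterance lengths, a boolean padding mask from make_pad_mask, and a (T, N, C) permute around the encoder call. Below is a minimal, self-contained sketch of that convention (not part of the diff); make_pad_mask here is a local stand-in with the same semantics as icefall.utils.make_pad_mask (True at padded frames), and the tensors are dummies, since Zipformer2 itself is only symlinked into this recipe.

    import torch

    def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
        # Stand-in with the same semantics as icefall.utils.make_pad_mask:
        # True marks padded frames.
        max_len = int(lengths.max())
        steps = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, T)
        return steps >= lengths.unsqueeze(1)  # (N, T)

    x = torch.randn(2, 6, 80)      # N=2 padded utterances, T=6 frames, C=80 fbank bins
    x_lens = torch.tensor([6, 4])  # number of valid frames before padding

    src_key_padding_mask = make_pad_mask(x_lens)  # (2, 6) bool; second row True at frames 4, 5
    x_t = x.permute(1, 0, 2)                      # (N, T, C) -> (T, N, C), as forward_encoder does
    print(src_key_padding_mask.shape, x_t.shape)  # torch.Size([2, 6]) torch.Size([6, 2, 80])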
@@ -229,15 +268,17 @@ class SPEECH_LLM(nn.Module):

     def decode(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
+        fbank: torch.Tensor,
+        fbank_lens: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
         **kwargs,
     ):
-        encoder_outs = self.encoder(fbank)
+        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+
         speech_features = self.encoder_projector(encoder_outs)
         speech_features = speech_features.to(torch.float16)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,

egs/speech_llm/ASR_LLM/zipformer_llm_zh/scaling.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py

egs/speech_llm/ASR_LLM/zipformer_llm_zh/subsampling.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py

egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
@@ -18,28 +18,17 @@
 # limitations under the License.
 """
 Usage:
-# fine-tuning with whisper and Qwen2
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
-
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
   --max-duration 200 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --exp-dir ./zipformer_llm_zh/exp_test \
+  --speech-encoder-path-or-name models/zipformer/exp/epoch-999.pt \
   --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
   --manifest-dir data/fbank \
   --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --deepspeed_config ./zipformer_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
+  --use-lora True \
+  --unfreeze-llm True
 """

 import argparse
@@ -53,7 +42,6 @@ import deepspeed
 import torch
 import torch.nn as nn
 import transformers
-import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from lhotse.cut import Cut
@@ -61,10 +49,12 @@ from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zipformer import Zipformer2

 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
@@ -90,14 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--speech-encoder-path-or-name",
         type=str,
-        default="whisper-large-v2",
+        default="zipformer",
         help="Path or name of the speech encoder.",
     )

     parser.add_argument(
         "--encoder-projector-ds-rate",
         type=int,
-        default=8,
+        default=4,
         help="Downsample rate for the encoder projector.",
     )
     parser.add_argument(
@@ -121,6 +111,185 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to unfreeze llm during training.",
     )

+    # Zipformer
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,2,3,4,3,2",
+        help="Number of zipformer encoder layers per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--downsampling-factor",
+        type=str,
+        default="1,2,4,8,4,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dim",
+        type=str,
+        default="512,768,1024,1536,1024,768",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="4,4,4,8,4,4",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="192,256,384,512,384,256",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default="48",
+        help="Positional-encoding embedding dimension",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="192,192,256,256,256,192",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31,31,15,15,15,31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--causal",
+        type=str2bool,
+        default=False,
+        help="If True, use causal version of model.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="16,32,64,-1",
+        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+        " Must be just -1 if --causal=False",
+    )
+
+    parser.add_argument(
+        "--left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Maximum left-contexts for causal training, measured in frames which will "
+        "be converted to a number of chunks. If splitting into chunks, "
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+    )
+
+
+def load_model_params(ckpt: str, model: nn.Module, module: str, strict: bool = True):
+    """Load model params from checkpoint
+
+    Args:
+        ckpt (str): Path to the checkpoint
+        model (nn.Module): model to be loaded
+        module (str): Module to be initialized
+
+    """
+    logging.info(f"Loading checkpoint from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu")
+
+    src_state_dict = checkpoint["model"]
+    dst_state_dict = model.state_dict()
+    logging.info(f"Loading parameters starting with prefix {module}")
+    module_prefix = module.strip() + "."
+    src_keys = [
+        k[len(module_prefix) :]
+        for k in src_state_dict.keys()
+        if k.startswith(module_prefix)
+    ]
+    dst_keys = [k for k in dst_state_dict.keys()]
+    assert set(src_keys) == set(dst_keys)  # two sets should match exactly
+    for key in src_keys:
+        dst_state_dict[key] = src_state_dict.pop(module_prefix + key)
+
+    model.load_state_dict(dst_state_dict, strict=strict)
+
+    return None
+
+
+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    # encoder_embed converts the input of shape (N, T, num_features)
+    # to the shape (N, (T - 7) // 2, encoder_dims).
+    # That is, it does two things simultaneously:
+    #   (1) subsampling: T -> (T - 7) // 2
+    #   (2) embedding: num_features -> encoder_dims
+    # In the normal configuration, we will downsample once more at the end
+    # by a factor of 2, and most of the encoder stacks will run at a lower
+    # sampling rate.
+    encoder_embed = Conv2dSubsampling(
+        in_channels=params.feature_dim,
+        out_channels=_to_int_tuple(params.encoder_dim)[0],
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+    )
+    return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    encoder = Zipformer2(
+        output_downsampling_factor=2,
+        downsampling_factor=_to_int_tuple(params.downsampling_factor),
+        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+        encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=_to_int_tuple(params.query_head_dim),
+        pos_head_dim=_to_int_tuple(params.pos_head_dim),
+        value_head_dim=_to_int_tuple(params.value_head_dim),
+        pos_dim=params.pos_dim,
+        num_heads=_to_int_tuple(params.num_heads),
+        feedforward_dim=_to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
+        causal=params.causal,
+        chunk_size=_to_int_tuple(params.chunk_size),
+        left_context_frames=_to_int_tuple(params.left_context_frames),
+    )
+    return encoder
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
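As a quick standalone illustration (not part of the diff) of how the comma-separated zipformer options above are consumed: _to_int_tuple turns each string into a per-stack tuple, and decode.py/train.py then take the widest stack, max(_to_int_tuple(params.encoder_dim)), as speech_encoder_dim when sizing the encoder projector. The values below are simply the argument defaults added in this commit.

    def _to_int_tuple(s: str):
        # same helper as added in this commit
        return tuple(map(int, s.split(",")))

    encoder_dim = _to_int_tuple("192,256,384,512,384,256")  # --encoder-dim default
    num_encoder_layers = _to_int_tuple("2,2,3,4,3,2")        # --num-encoder-layers default
    downsampling_factor = _to_int_tuple("1,2,4,8,4,2")       # --downsampling-factor default

    # decode.py / train.py size the encoder projector from the widest stack:
    speech_encoder_dim = max(encoder_dim)
    assert speech_encoder_dim == 512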
@@ -154,7 +323,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="whisper_qwen/exp",
+        default="zipformer_llm_zh/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -166,7 +335,7 @@ def get_parser():
         type=str,
         default=None,
         help="""The path to the pretrained model if it is not None. Training will
-        start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
+        start from this model. e.g. ./wenetspeech/ASR/zipformer/exp/epoch-999.pt
         """,
     )

@@ -231,7 +400,6 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "allowed_excess_duration_ratio": 0.1,
-            "subsampling_factor": 2,
             "frame_shift_ms": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -241,6 +409,9 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 5000,
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
             "env_info": get_env_info(),
         }
     )
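Rough bookkeeping behind the new zipformer-specific get_params() entries, assuming 10 ms fbank frames: Conv2dSubsampling roughly halves the frame count ((T - 7) // 2, per the comment in get_encoder_embed), and output_downsampling_factor=2 halves it again, so the fixed "subsampling_factor": 4 matches the encoder output rate used when counting frames in compute_loss. The numbers below are illustrative only.

    frame_shift_ms = 10        # from get_params()
    subsampling_factor = 4     # from get_params(); "not passed in, this is fixed."

    T = 1000                          # e.g. a 10-second cut -> 1000 fbank frames at 100 Hz
    after_embed = (T - 7) // 2        # Conv2dSubsampling, per get_encoder_embed() -> 496
    after_encoder = after_embed // 2  # output_downsampling_factor=2 in get_encoder_model() -> 248

    print(after_encoder, T // subsampling_factor)  # 248 vs. 250: both roughly T / 4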
@@ -327,14 +498,13 @@ def compute_loss(
         return input_ids, attention_mask, target_ids

     device = next(model.parameters()).device

     feature = batch["inputs"]
+
     assert feature.ndim == 3
-    feature = feature.to(device)
-    feature = feature.transpose(1, 2)  # (N, C, T)
-
-    batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"]
+
     texts = batch["supervisions"]["text"]

     messages = []
@@ -353,7 +523,8 @@ def compute_loss(

     with torch.set_grad_enabled(is_training):
         model_outputs, acc = model(
-            fbank=feature,
+            fbank=feature.to(device),
+            fbank_lens=feature_lens.to(device),
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
             labels=target_ids.to(device),
@@ -364,7 +535,6 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        feature_lens = supervisions["num_frames"]
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
@@ -378,7 +548,7 @@

 def compute_validation_loss(
     params: AttributeDict,
-    tokenizer: whisper.tokenizer.Tokenizer,
+    tokenizer: AutoTokenizer,
     model: nn.Module,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -586,10 +756,16 @@ def run(rank, world_size, args):

     logging.info("About to create model")

-    replace_whisper_encoder_forward()
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
+    speech_encoder_embed = get_encoder_embed(params)
+    speech_encoder = get_encoder_model(params)
+    load_model_params(
+        params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
+    )
+    load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
+
+    speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
+    for name, param in speech_encoder_embed.named_parameters():
+        param.requires_grad = False
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False

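For reference, a toy sketch (not part of the diff) of what the two load_model_params calls above rely on: the pretrained zipformer ASR checkpoint stores submodule weights under prefixes such as "encoder." and "encoder_embed.", and load_model_params strips the requested prefix before calling load_state_dict. The nn.Linear below is a hypothetical stand-in for the real encoder, and the extra "joiner.weight" key is only there to show that keys of other modules are ignored.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the real zipformer encoder.
    encoder = nn.Linear(4, 4)

    # Toy checkpoint in the layout load_model_params expects: a "model" dict whose
    # keys are prefixed with the submodule name ("encoder.", "encoder_embed.", ...).
    checkpoint = {
        "model": {"encoder." + k: v.clone() for k, v in encoder.state_dict().items()}
    }
    checkpoint["model"]["joiner.weight"] = torch.zeros(1)  # keys of other modules are ignored

    # What load_model_params(ckpt, encoder, "encoder") does, in miniature:
    src = checkpoint["model"]
    sub = {k[len("encoder.") :]: v for k, v in src.items() if k.startswith("encoder.")}
    encoder.load_state_dict(sub, strict=True)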
@@ -646,6 +822,7 @@ def run(rank, world_size, args):
     )

     model = SPEECH_LLM(
+        speech_encoder_embed,
         speech_encoder,
         llm,
         encoder_projector,
@@ -790,7 +967,6 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)

-    supervisions = batch["supervisions"]
     features = batch["inputs"]

     logging.info(f"features shape: {features.shape}")

egs/speech_llm/ASR_LLM/zipformer_llm_zh/zipformer.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py