diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py index 829ef4e2d..bc6a94613 100644 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py @@ -194,10 +194,10 @@ class SPEECH_LLM(nn.Module): def forward( self, - fbank: torch.Tensor = None, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor = None, - labels: torch.LongTensor = None, + fbank: torch.Tensor, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + labels: torch.LongTensor, ): encoder_outs = self.encoder(fbank) diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/decode.py index 3036b471e..2b757c2de 100755 --- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/decode.py @@ -3,6 +3,7 @@ # Fangjun Kuang, # Wei Kang) # 2024 Yuekai Zhang +# 2025 Yifan Yang # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,31 +20,17 @@ # limitations under the License. """ Usage: -# Command for decoding using fine-tuned models: - -pip install huggingface_hub['cli'] -mkdir -p models/whisper models/qwen models/checkpoint -huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B - -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct - -mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B -ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt - -python3 ./whisper_llm_zh/decode.py \ +python3 ./zipformer_llm_zh/decode.py \ --max-duration 80 \ - --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \ - --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --exp-dir zipformer_llm_zh/exp \ + --speech-encoder-path-or-name models/zipformer/epoch-999.pt \ --llm-path-or-name models/qwen \ - --epoch 999 --avg 1 \ + --epoch 999 \ + --avg 1 \ --manifest-dir data/fbank \ --use-flash-attn True \ - --use-lora True --dataset aishell + --use-lora True \ + --dataset aishell """ import argparse @@ -56,15 +43,22 @@ import k2 import torch import torch.nn as nn import transformers -import whisper from asr_datamodule import AsrDataModule from lhotse.cut import Cut from model import SPEECH_LLM, EncoderProjector from multi_dataset import MultiDataset from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training -from train import DEFAULT_SPEECH_TOKEN +from train import ( + DEFAULT_SPEECH_TOKEN, + _to_int_tuple, + add_model_arguments, + get_encoder_embed, + get_encoder_model, + get_params, + load_model_params, +) from transformers import AutoModelForCausalLM, AutoTokenizer -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zipformer import Zipformer2 from icefall.checkpoint import load_checkpoint from icefall.env import get_env_info @@ -129,43 +123,6 @@ def average_checkpoints( return avg -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--llm-path-or-name", - type=str, - 
default="/workspace/asr/Qwen1.5-0.5B-Chat", - help="Path or name of the large language model.", - ) - - parser.add_argument( - "--speech-encoder-path-or-name", - type=str, - default="whisper-large-v2", - help="Path or name of the speech encoder.", - ) - - parser.add_argument( - "--encoder-projector-ds-rate", - type=int, - default=8, - help="Downsample rate for the encoder projector.", - ) - - parser.add_argument( - "--use-flash-attn", - type=str2bool, - default=True, - help="Whether to use flash attention.", - ) - - parser.add_argument( - "--use-lora", - type=str2bool, - default=True, - help="Whether to use lora fine-tuned llm checkpoint.", - ) - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -207,17 +164,10 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="whisper/exp", + default="zipformer/exp", help="The experiment dir", ) - parser.add_argument( - "--remove-whisper-encoder-input-length-restriction", - type=str2bool, - default=True, - help="replace whisper encoder forward method to remove input length restriction", - ) - parser.add_argument( "--dataset", type=str, @@ -230,15 +180,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - "env_info": get_env_info(), - } - ) - return params - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -299,28 +240,13 @@ def decode_one_batch( return input_ids, attention_mask - dtype = torch.float32 device = model.llm.device feature = batch["inputs"] assert feature.ndim == 3 - feature = feature.to(device, dtype=dtype).transpose(1, 2) - if not params.remove_whisper_encoder_input_length_restriction: - T = 3000 - if feature.shape[2] < T: - feature = torch.cat( - [ - feature, - torch.zeros( - feature.shape[0], feature.shape[1], T - feature.shape[2] - ).to(device, dtype=dtype), - ], - 2, - ) supervisions = batch["supervisions"] - feature_len = supervisions["num_frames"] - feature_len = feature_len.to(device, dtype=dtype) + feature_lens = supervisions["num_frames"] messages = [ [ @@ -332,7 +258,10 @@ def decode_one_batch( input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128) generated_ids = model.decode( - feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device) + feature.to(device), + feature_lens.to(device), + input_ids.to(device, dtype=torch.long), + attention_mask.to(device), ) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) @@ -471,12 +400,15 @@ def main(): logging.info(f"device: {device}") - if params.remove_whisper_encoder_input_length_restriction: - replace_whisper_encoder_forward() + speech_encoder_embed = get_encoder_embed(params) + speech_encoder = get_encoder_model(params) + load_model_params( + params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed" + ) + load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder") + + speech_encoder_dim = max(_to_int_tuple(params.encoder_dim)) - whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") - speech_encoder = whisper_model.encoder - speech_encoder_dim = whisper_model.dims.n_audio_state tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) if params.use_flash_attn: @@ -528,6 +460,7 @@ def main(): ) model = SPEECH_LLM( + speech_encoder_embed, speech_encoder, llm, encoder_projector, diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/encoder_interface.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/encoder_interface.py new 
file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py index 829ef4e2d..5f0d4b8e5 100644 --- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py @@ -1,7 +1,12 @@ +from typing import Tuple + import torch +from encoder_interface import EncoderInterface from torch import nn from transformers.trainer_pt_utils import LabelSmoother +from icefall.utils import make_pad_mask + IGNORE_TOKEN_ID = LabelSmoother.ignore_index @@ -55,11 +60,13 @@ class SPEECH_LLM(nn.Module): def __init__( self, - encoder: nn.Module, + encoder_embed: nn.Module, + encoder: EncoderInterface, llm: nn.Module, encoder_projector: nn.Module, ): super().__init__() + self.encoder_embed = encoder_embed self.encoder = encoder self.llm = llm self.encoder_projector = encoder_projector @@ -192,14 +199,46 @@ class SPEECH_LLM(nn.Module): return final_embedding, final_attention_mask, final_labels, position_ids + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + def forward( self, - fbank: torch.Tensor = None, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor = None, - labels: torch.LongTensor = None, + fbank: torch.Tensor, + fbank_lens: torch.Tensor, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + labels: torch.LongTensor, ): - encoder_outs = self.encoder(fbank) + encoder_outs, _ = self.forward_encoder(fbank, fbank_lens) speech_features = self.encoder_projector(encoder_outs) @@ -229,15 +268,17 @@ class SPEECH_LLM(nn.Module): def decode( self, - fbank: torch.Tensor = None, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor = None, + fbank: torch.Tensor, + fbank_lens: torch.Tensor, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, **kwargs, ): + encoder_outs, _ = self.forward_encoder(fbank, fbank_lens) - encoder_outs = self.encoder(fbank) speech_features = self.encoder_projector(encoder_outs) speech_features = speech_features.to(torch.float16) + inputs_embeds = self.llm.get_input_embeddings()(input_ids) ( inputs_embeds, diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/scaling.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/scaling.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/subsampling.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py index 7947a60a5..77c6a9b95 100755 --- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py @@ -18,28 +18,17 @@ # limitations under the License. """ Usage: -# fine-tuning with whisper and Qwen2 -pip install huggingface_hub['cli'] -mkdir -p models/whisper models/qwen - -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct - -torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ +torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \ --max-duration 200 \ - --exp-dir ./whisper_llm_zh/exp_test \ - --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --exp-dir ./zipformer_llm_zh/exp_test \ + --speech-encoder-path-or-name models/zipformer/exp/epoch-999.pt \ --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ --manifest-dir data/fbank \ --deepspeed \ - --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ + --deepspeed_config ./zipformer_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora True --unfreeze-llm True + --use-lora True \ + --unfreeze-llm True """ import argparse @@ -53,7 +42,6 @@ import deepspeed import torch import torch.nn as nn import transformers -import whisper from asr_datamodule import AsrDataModule from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from lhotse.cut import Cut @@ -61,10 +49,12 @@ from lhotse.utils import fix_random_seed from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from multi_dataset import MultiDataset from peft import LoraConfig, get_peft_model +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling from torch import Tensor from torch.utils.tensorboard import SummaryWriter from transformers import AutoModelForCausalLM, AutoTokenizer -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zipformer import Zipformer2 from icefall.dist import get_rank, get_world_size from icefall.env import get_env_info @@ -90,14 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--speech-encoder-path-or-name", type=str, - default="whisper-large-v2", + default="zipformer", help="Path or name of the speech encoder.", ) parser.add_argument( "--encoder-projector-ds-rate", type=int, - default=8, + default=4, help="Downsample rate for the encoder projector.", ) parser.add_argument( @@ -121,6 +111,185 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Whether to unfreeze llm during training.", ) + # Zipformer + 
parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + +def load_model_params(ckpt: str, model: nn.Module, module: str, strict: bool = True): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + module (str): Module to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + logging.info(f"Loading parameters starting with prefix {module}") + module_prefix = module.strip() + "." 
+ src_keys = [ + k[len(module_prefix) :] + for k in src_state_dict.keys() + if k.startswith(module_prefix) + ] + dst_keys = [k for k in dst_state_dict.keys()] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(module_prefix + key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + def get_parser(): parser = argparse.ArgumentParser( @@ -154,7 +323,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="whisper_qwen/exp", + default="zipformer_llm_zh/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -166,7 +335,7 @@ def get_parser(): type=str, default=None, help="""The path to the pretrained model if it is not None. Training will - start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt + start from this model. e.g. ./wenetspeech/ASR/zipformer/exp/epoch-999.pt """, ) @@ -231,7 +400,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "allowed_excess_duration_ratio": 0.1, - "subsampling_factor": 2, "frame_shift_ms": 10, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), @@ -241,6 +409,9 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 5000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
"env_info": get_env_info(), } ) @@ -327,14 +498,13 @@ def compute_loss( return input_ids, attention_mask, target_ids device = next(model.parameters()).device + feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - feature = feature.transpose(1, 2) # (N, C, T) - batch_idx_train = params.batch_idx_train supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"] + texts = batch["supervisions"]["text"] messages = [] @@ -353,7 +523,8 @@ def compute_loss( with torch.set_grad_enabled(is_training): model_outputs, acc = model( - fbank=feature, + fbank=feature.to(device), + fbank_lens=feature_lens.to(device), input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), labels=target_ids.to(device), @@ -364,7 +535,6 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - feature_lens = supervisions["num_frames"] info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. @@ -378,7 +548,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - tokenizer: whisper.tokenizer.Tokenizer, + tokenizer: AutoTokenizer, model: nn.Module, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -586,10 +756,16 @@ def run(rank, world_size, args): logging.info("About to create model") - replace_whisper_encoder_forward() - whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") - speech_encoder = whisper_model.encoder - speech_encoder_dim = whisper_model.dims.n_audio_state + speech_encoder_embed = get_encoder_embed(params) + speech_encoder = get_encoder_model(params) + load_model_params( + params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed" + ) + load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder") + + speech_encoder_dim = max(_to_int_tuple(params.encoder_dim)) + for name, param in speech_encoder_embed.named_parameters(): + param.requires_grad = False for name, param in speech_encoder.named_parameters(): param.requires_grad = False @@ -646,6 +822,7 @@ def run(rank, world_size, args): ) model = SPEECH_LLM( + speech_encoder_embed, speech_encoder, llm, encoder_projector, @@ -790,7 +967,6 @@ def display_and_save_batch( logging.info(f"Saving batch to {filename}") torch.save(batch, filename) - supervisions = batch["supervisions"] features = batch["inputs"] logging.info(f"features shape: {features.shape}") diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/zipformer.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file
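
The refactor above replaces whole-model Whisper loading with per-module initialisation: the new `load_model_params` in `train.py` copies only the checkpoint entries whose keys start with a given prefix (`"encoder_embed."` or `"encoder."`) into the corresponding Zipformer sub-module, which is also how `decode.py` rebuilds the frozen speech encoder from a combined `epoch-999.pt` before wrapping it in `SPEECH_LLM`. Below is a minimal, self-contained sketch of that prefix-stripping pattern; it is not part of the patch, and the toy `nn.Linear` stand-ins and the `load_module_params` name are illustrative assumptions rather than code from this recipe.

# Standalone illustration of prefix-based sub-module loading (plain PyTorch only).
import torch
import torch.nn as nn

# A toy "full model" whose state_dict mimics a combined checkpoint with
# keys such as "encoder.weight" and "encoder_embed.weight".
full_model = nn.ModuleDict(
    {"encoder_embed": nn.Linear(4, 8), "encoder": nn.Linear(8, 8)}
)
checkpoint = {"model": full_model.state_dict()}


def load_module_params(checkpoint: dict, module: nn.Module, prefix: str) -> None:
    """Copy checkpoint entries starting with `prefix + '.'` into `module`."""
    module_prefix = prefix + "."
    src = checkpoint["model"]
    dst = {
        k[len(module_prefix):]: v
        for k, v in src.items()
        if k.startswith(module_prefix)
    }
    # After stripping the prefix, the keys must match the target module exactly.
    assert set(dst.keys()) == set(module.state_dict().keys())
    module.load_state_dict(dst, strict=True)


encoder = nn.Linear(8, 8)        # stand-in for Zipformer2
encoder_embed = nn.Linear(4, 8)  # stand-in for Conv2dSubsampling
load_module_params(checkpoint, encoder, "encoder")
load_module_params(checkpoint, encoder_embed, "encoder_embed")
print(torch.equal(encoder.weight, full_model["encoder"].weight))  # True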