support zipformer encoder

update

update

update

update

fix

reformat

support infer

update
Yifan Yang 2025-05-08 04:31:34 +00:00 committed by yfyeung
parent 211c01bc1d
commit 489c42b45e
8 changed files with 304 additions and 150 deletions

View File

@@ -194,10 +194,10 @@ class SPEECH_LLM(nn.Module):
     def forward(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        labels: torch.LongTensor = None,
+        fbank: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
+        labels: torch.LongTensor,
     ):
         encoder_outs = self.encoder(fbank)

View File

@@ -3,6 +3,7 @@
 # Fangjun Kuang,
 # Wei Kang)
 # 2024 Yuekai Zhang
+# 2025 Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -19,31 +20,17 @@
 # limitations under the License.
 """
 Usage:
-# Command for decoding using fine-tuned models:
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen models/checkpoint
-huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
-ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
-python3 ./whisper_llm_zh/decode.py \
+python3 ./zipformer_llm_zh/decode.py \
   --max-duration 80 \
-  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --exp-dir zipformer_llm_zh/exp \
+  --speech-encoder-path-or-name models/zipformer/epoch-999.pt \
   --llm-path-or-name models/qwen \
-  --epoch 999 --avg 1 \
+  --epoch 999 \
+  --avg 1 \
   --manifest-dir data/fbank \
   --use-flash-attn True \
-  --use-lora True --dataset aishell
+  --use-lora True \
+  --dataset aishell
 """

 import argparse
@@ -56,15 +43,22 @@ import k2
 import torch
 import torch.nn as nn
 import transformers
-import whisper
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-from train import DEFAULT_SPEECH_TOKEN
+from train import (
+    DEFAULT_SPEECH_TOKEN,
+    _to_int_tuple,
+    add_model_arguments,
+    get_encoder_embed,
+    get_encoder_model,
+    get_params,
+    load_model_params,
+)
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zipformer import Zipformer2

 from icefall.checkpoint import load_checkpoint
 from icefall.env import get_env_info
@@ -129,43 +123,6 @@ def average_checkpoints(
     return avg


-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
-        help="Path or name of the large language model.",
-    )
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -207,17 +164,10 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="whisper/exp",
+        default="zipformer/exp",
         help="The experiment dir",
     )

-    parser.add_argument(
-        "--remove-whisper-encoder-input-length-restriction",
-        type=str2bool,
-        default=True,
-        help="replace whisper encoder forward method to remove input length restriction",
-    )
-
     parser.add_argument(
         "--dataset",
         type=str,
@@ -230,15 +180,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -299,28 +240,13 @@ def decode_one_batch(
         return input_ids, attention_mask

-    dtype = torch.float32
     device = model.llm.device

     feature = batch["inputs"]
     assert feature.ndim == 3
-    feature = feature.to(device, dtype=dtype).transpose(1, 2)
-    if not params.remove_whisper_encoder_input_length_restriction:
-        T = 3000
-        if feature.shape[2] < T:
-            feature = torch.cat(
-                [
-                    feature,
-                    torch.zeros(
-                        feature.shape[0], feature.shape[1], T - feature.shape[2]
-                    ).to(device, dtype=dtype),
-                ],
-                2,
-            )

     supervisions = batch["supervisions"]
-    feature_len = supervisions["num_frames"]
-    feature_len = feature_len.to(device, dtype=dtype)
+    feature_lens = supervisions["num_frames"]

     messages = [
         [
@@ -332,7 +258,10 @@ def decode_one_batch(
     input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

     generated_ids = model.decode(
-        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        feature.to(device),
+        feature_lens.to(device),
+        input_ids.to(device, dtype=torch.long),
+        attention_mask.to(device),
     )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
@@ -471,12 +400,15 @@ def main():
     logging.info(f"device: {device}")

-    if params.remove_whisper_encoder_input_length_restriction:
-        replace_whisper_encoder_forward()
-
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
+    speech_encoder_embed = get_encoder_embed(params)
+    speech_encoder = get_encoder_model(params)
+    load_model_params(
+        params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
+    )
+    load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
+    speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))

     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

     if params.use_flash_attn:
@@ -528,6 +460,7 @@ def main():
     )

     model = SPEECH_LLM(
+        speech_encoder_embed,
         speech_encoder,
         llm,
         encoder_projector,
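Note: a minimal sketch of the new decode_one_batch calling convention, with a toy stand-in for SPEECH_LLM (the class and tensors below are illustrative assumptions, not recipe code). The point it shows: the padded fbank batch and its per-utterance lengths now travel together into decode, instead of a whisper-style fixed-length feature tensor.

import torch


class ToySpeechLLM:
    def decode(self, fbank, fbank_lens, input_ids, attention_mask, **kwargs):
        # The real SPEECH_LLM.decode runs forward_encoder plus LLM generation;
        # here we only echo shapes to show what the caller must provide.
        return [f"fbank {tuple(fbank.shape)}, lens {fbank_lens.tolist()}"]


batch = {
    "inputs": torch.randn(2, 120, 80),  # (N, T, C) padded fbank features
    "supervisions": {"num_frames": torch.tensor([120, 90])},  # valid frames per cut
}
feature = batch["inputs"]
feature_lens = batch["supervisions"]["num_frames"]
input_ids = torch.ones(2, 8, dtype=torch.long)
attention_mask = torch.ones(2, 8, dtype=torch.long)

print(ToySpeechLLM().decode(feature, feature_lens, input_ids, attention_mask))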

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py

View File

@@ -1,7 +1,12 @@
+from typing import Tuple
+
 import torch
+from encoder_interface import EncoderInterface
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother

+from icefall.utils import make_pad_mask
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@@ -55,11 +60,13 @@ class SPEECH_LLM(nn.Module):
     def __init__(
         self,
-        encoder: nn.Module,
+        encoder_embed: nn.Module,
+        encoder: EncoderInterface,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
         super().__init__()
+        self.encoder_embed = encoder_embed
         self.encoder = encoder
         self.llm = llm
         self.encoder_projector = encoder_projector
@@ -192,14 +199,46 @@ class SPEECH_LLM(nn.Module):
         return final_embedding, final_attention_mask, final_labels, position_ids

+    def forward_encoder(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute encoder outputs.
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+
+        Returns:
+          encoder_out:
+            Encoder output, of shape (N, T, C).
+          encoder_out_lens:
+            Encoder output lengths, of shape (N,).
+        """
+        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+        x, x_lens = self.encoder_embed(x, x_lens)
+        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
+
+        return encoder_out, encoder_out_lens
+
     def forward(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        labels: torch.LongTensor = None,
+        fbank: torch.Tensor,
+        fbank_lens: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
+        labels: torch.LongTensor,
     ):
-        encoder_outs = self.encoder(fbank)
+        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)

         speech_features = self.encoder_projector(encoder_outs)
@@ -229,15 +268,17 @@ class SPEECH_LLM(nn.Module):
     def decode(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
+        fbank: torch.Tensor,
+        fbank_lens: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
         **kwargs,
     ):
-        encoder_outs = self.encoder(fbank)
+        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+
         speech_features = self.encoder_projector(encoder_outs)
         speech_features = speech_features.to(torch.float16)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
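Note: the shape handling in forward_encoder can be checked in isolation. Below is a self-contained sketch, not recipe code: toy_encoder_embed and toy_encoder stand in for Conv2dSubsampling and Zipformer2, and the simplified make_pad_mask mirrors icefall.utils.make_pad_mask (True marks padded frames).

import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Simplified: bool mask of shape (N, T_max); True marks padded positions.
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


def toy_encoder_embed(x, x_lens):
    # Stand-in for the subsampling front end: halve the time axis, track lengths.
    return x[:, ::2, :], x_lens // 2


def toy_encoder(x, x_lens, src_key_padding_mask):
    # Stand-in for Zipformer2: expects (T, N, C) input, returns it unchanged.
    return x, x_lens


fbank = torch.randn(2, 100, 80)        # (N, T, C) filter-bank features
fbank_lens = torch.tensor([100, 60])   # valid frames before padding

x, x_lens = toy_encoder_embed(fbank, fbank_lens)
mask = make_pad_mask(x_lens)           # (N, T'), True where padded
x = x.permute(1, 0, 2)                 # (N, T', C) -> (T', N, C)
out, out_lens = toy_encoder(x, x_lens, mask)
out = out.permute(1, 0, 2)             # (T', N, C) -> (N, T', C)
print(out.shape, out_lens)             # torch.Size([2, 50, 80]) tensor([50, 30])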

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py

View File

@@ -18,28 +18,17 @@
 # limitations under the License.
 """
 Usage:
-# fine-tuning with whisper and Qwen2
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
   --max-duration 200 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --exp-dir ./zipformer_llm_zh/exp_test \
+  --speech-encoder-path-or-name models/zipformer/exp/epoch-999.pt \
   --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
   --manifest-dir data/fbank \
   --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --deepspeed_config ./zipformer_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
+  --use-lora True \
+  --unfreeze-llm True
 """

 import argparse
@@ -53,7 +42,6 @@ import deepspeed
 import torch
 import torch.nn as nn
 import transformers
-import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from lhotse.cut import Cut
@@ -61,10 +49,12 @@ from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zipformer import Zipformer2

 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
@@ -90,14 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--speech-encoder-path-or-name",
         type=str,
-        default="whisper-large-v2",
+        default="zipformer",
         help="Path or name of the speech encoder.",
     )

     parser.add_argument(
         "--encoder-projector-ds-rate",
         type=int,
-        default=8,
+        default=4,
         help="Downsample rate for the encoder projector.",
     )

     parser.add_argument(
@@ -121,6 +111,185 @@
         help="Whether to unfreeze llm during training.",
     )

+    # Zipformer
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,2,3,4,3,2",
+        help="Number of zipformer encoder layers per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--downsampling-factor",
+        type=str,
+        default="1,2,4,8,4,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dim",
+        type=str,
+        default="512,768,1024,1536,1024,768",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="4,4,4,8,4,4",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="192,256,384,512,384,256",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default="48",
+        help="Positional-encoding embedding dimension",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="192,192,256,256,256,192",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31,31,15,15,15,31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--causal",
+        type=str2bool,
+        default=False,
+        help="If True, use causal version of model.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="16,32,64,-1",
+        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+        "Must be just -1 if --causal=False",
+    )
+
+    parser.add_argument(
+        "--left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Maximum left-contexts for causal training, measured in frames which will "
+        "be converted to a number of chunks. If splitting into chunks, "
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+    )
+
+
+def load_model_params(ckpt: str, model: nn.Module, module: str, strict: bool = True):
+    """Load model params from checkpoint
+
+    Args:
+        ckpt (str): Path to the checkpoint
+        model (nn.Module): model to be loaded
+        module (str): Module to be initialized
+    """
+    logging.info(f"Loading checkpoint from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu")
+
+    src_state_dict = checkpoint["model"]
+    dst_state_dict = model.state_dict()
+    logging.info(f"Loading parameters starting with prefix {module}")
+    module_prefix = module.strip() + "."
+    src_keys = [
+        k[len(module_prefix) :]
+        for k in src_state_dict.keys()
+        if k.startswith(module_prefix)
+    ]
+    dst_keys = [k for k in dst_state_dict.keys()]
+    assert set(src_keys) == set(dst_keys)  # two sets should match exactly
+    for key in src_keys:
+        dst_state_dict[key] = src_state_dict.pop(module_prefix + key)
+
+    model.load_state_dict(dst_state_dict, strict=strict)
+
+    return None
+
+
+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    # encoder_embed converts the input of shape (N, T, num_features)
+    # to the shape (N, (T - 7) // 2, encoder_dims).
+    # That is, it does two things simultaneously:
+    #   (1) subsampling: T -> (T - 7) // 2
+    #   (2) embedding: num_features -> encoder_dims
+    # In the normal configuration, we will downsample once more at the end
+    # by a factor of 2, and most of the encoder stacks will run at a lower
+    # sampling rate.
+    encoder_embed = Conv2dSubsampling(
+        in_channels=params.feature_dim,
+        out_channels=_to_int_tuple(params.encoder_dim)[0],
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+    )
+    return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    encoder = Zipformer2(
+        output_downsampling_factor=2,
+        downsampling_factor=_to_int_tuple(params.downsampling_factor),
+        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+        encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=_to_int_tuple(params.query_head_dim),
+        pos_head_dim=_to_int_tuple(params.pos_head_dim),
+        value_head_dim=_to_int_tuple(params.value_head_dim),
+        pos_dim=params.pos_dim,
+        num_heads=_to_int_tuple(params.num_heads),
+        feedforward_dim=_to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
+        causal=params.causal,
+        chunk_size=_to_int_tuple(params.chunk_size),
+        left_context_frames=_to_int_tuple(params.left_context_frames),
+    )
+    return encoder
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
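Note: a hedged sketch (not part of the recipe) of two helpers added in the hunk above. _to_int_tuple turns the comma-separated zipformer hyperparameters into per-stack tuples, and max(encoder_dim) is what the training and decoding scripts use as speech_encoder_dim; the prefix filtering shows why load_model_params appends a trailing "." to the module name. The tiny state dict is fabricated for illustration; real checkpoints keep the full model state dict under checkpoint["model"].

import torch


def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))


encoder_dim = _to_int_tuple("192,256,384,512,384,256")
print(encoder_dim)       # (192, 256, 384, 512, 384, 256)
print(max(encoder_dim))  # 512 -> used as speech_encoder_dim

# The trailing "." keeps "encoder_embed.*" keys out of the "encoder" module.
src_state_dict = {
    "encoder_embed.conv.0.weight": torch.zeros(1),
    "encoder.encoders.0.weight": torch.zeros(1),
}
module_prefix = "encoder" + "."
kept = [k[len(module_prefix):] for k in src_state_dict if k.startswith(module_prefix)]
print(kept)  # ['encoders.0.weight']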
@@ -154,7 +323,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="whisper_qwen/exp",
+        default="zipformer_llm_zh/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -166,7 +335,7 @@ def get_parser():
         type=str,
         default=None,
         help="""The path to the pretrained model if it is not None. Training will
-        start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
+        start from this model. e.g. ./wenetspeech/ASR/zipformer/exp/epoch-999.pt
        """,
     )
@@ -231,7 +400,6 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "allowed_excess_duration_ratio": 0.1,
-            "subsampling_factor": 2,
             "frame_shift_ms": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -241,6 +409,9 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 5000,
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
             "env_info": get_env_info(),
         }
     )
@@ -327,14 +498,13 @@ def compute_loss(
         return input_ids, attention_mask, target_ids

     device = next(model.parameters()).device
     feature = batch["inputs"]
     assert feature.ndim == 3
-    feature = feature.to(device)
-    feature = feature.transpose(1, 2)  # (N, C, T)
-
-    batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"]

     texts = batch["supervisions"]["text"]
     messages = []
@@ -353,7 +523,8 @@ def compute_loss(
     with torch.set_grad_enabled(is_training):
         model_outputs, acc = model(
-            fbank=feature,
+            fbank=feature.to(device),
+            fbank_lens=feature_lens.to(device),
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
             labels=target_ids.to(device),
@@ -364,7 +535,6 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        feature_lens = supervisions["num_frames"]
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
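Note: a small arithmetic check of the "frames" statistic above, with assumed lengths rather than real data. With the zipformer setup params.subsampling_factor is fixed at 4, so the 10 ms fbank frames are counted at roughly a quarter of their original rate.

import torch

feature_lens = torch.tensor([1000, 640, 256])  # supervisions["num_frames"], 10 ms frames
frames = (feature_lens // 4).sum().item()      # params.subsampling_factor == 4
print(frames)  # 474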
@@ -378,7 +548,7 @@
 def compute_validation_loss(
     params: AttributeDict,
-    tokenizer: whisper.tokenizer.Tokenizer,
+    tokenizer: AutoTokenizer,
     model: nn.Module,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -586,10 +756,16 @@ def run(rank, world_size, args):
     logging.info("About to create model")

-    replace_whisper_encoder_forward()
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
+    speech_encoder_embed = get_encoder_embed(params)
+    speech_encoder = get_encoder_model(params)
+    load_model_params(
+        params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
+    )
+    load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
+    speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
+
+    for name, param in speech_encoder_embed.named_parameters():
+        param.requires_grad = False
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
@@ -646,6 +822,7 @@ def run(rank, world_size, args):
     )

     model = SPEECH_LLM(
+        speech_encoder_embed,
         speech_encoder,
         llm,
         encoder_projector,
@@ -790,7 +967,6 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)

-    supervisions = batch["supervisions"]
     features = batch["inputs"]

     logging.info(f"features shape: {features.shape}")

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py