support zipformer encoder

update

update

update

update

fix

reformat

support infer

update
This commit is contained in:
Yifan Yang 2025-05-08 04:31:34 +00:00 committed by yfyeung
parent 211c01bc1d
commit 489c42b45e
8 changed files with 304 additions and 150 deletions

View File

@ -194,10 +194,10 @@ class SPEECH_LLM(nn.Module):
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
fbank: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
labels: torch.LongTensor,
):
encoder_outs = self.encoder(fbank)

View File

@ -3,6 +3,7 @@
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
# 2025 Yifan Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -19,31 +20,17 @@
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
python3 ./zipformer_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--exp-dir zipformer_llm_zh/exp \
--speech-encoder-path-or-name models/zipformer/epoch-999.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--epoch 999 \
--avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
--use-lora True \
--dataset aishell
"""
import argparse
@ -56,15 +43,22 @@ import k2
import torch
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from train import (
DEFAULT_SPEECH_TOKEN,
_to_int_tuple,
add_model_arguments,
get_encoder_embed,
get_encoder_model,
get_params,
load_model_params,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zipformer import Zipformer2
from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
@ -129,43 +123,6 @@ def average_checkpoints(
return avg
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
help="Path or name of the large language model.",
)
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=True,
help="Whether to use lora fine-tuned llm checkpoint.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -207,17 +164,10 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="whisper/exp",
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
parser.add_argument(
"--dataset",
type=str,
@ -230,15 +180,6 @@ def get_parser():
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -299,28 +240,13 @@ def decode_one_batch(
return input_ids, attention_mask
dtype = torch.float32
device = model.llm.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
if not params.remove_whisper_encoder_input_length_restriction:
T = 3000
if feature.shape[2] < T:
feature = torch.cat(
[
feature,
torch.zeros(
feature.shape[0], feature.shape[1], T - feature.shape[2]
).to(device, dtype=dtype),
],
2,
)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
feature_lens = supervisions["num_frames"]
messages = [
[
@ -332,7 +258,10 @@ def decode_one_batch(
input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
feature.to(device),
feature_lens.to(device),
input_ids.to(device, dtype=torch.long),
attention_mask.to(device),
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
@ -471,12 +400,15 @@ def main():
logging.info(f"device: {device}")
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
speech_encoder_embed = get_encoder_embed(params)
speech_encoder = get_encoder_model(params)
load_model_params(
params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
)
load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
@ -528,6 +460,7 @@ def main():
)
model = SPEECH_LLM(
speech_encoder_embed,
speech_encoder,
llm,
encoder_projector,

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

View File

@ -1,7 +1,12 @@
from typing import Tuple
import torch
from encoder_interface import EncoderInterface
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from icefall.utils import make_pad_mask
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@ -55,11 +60,13 @@ class SPEECH_LLM(nn.Module):
def __init__(
self,
encoder: nn.Module,
encoder_embed: nn.Module,
encoder: EncoderInterface,
llm: nn.Module,
encoder_projector: nn.Module,
):
super().__init__()
self.encoder_embed = encoder_embed
self.encoder = encoder
self.llm = llm
self.encoder_projector = encoder_projector
@ -192,14 +199,46 @@ class SPEECH_LLM(nn.Module):
return final_embedding, final_attention_mask, final_labels, position_ids
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
fbank: torch.Tensor,
fbank_lens: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
labels: torch.LongTensor,
):
encoder_outs = self.encoder(fbank)
encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
speech_features = self.encoder_projector(encoder_outs)
@ -229,15 +268,17 @@ class SPEECH_LLM(nn.Module):
def decode(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
fbank: torch.Tensor,
fbank_lens: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
**kwargs,
):
encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

View File

@ -18,28 +18,17 @@
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--exp-dir ./zipformer_llm_zh/exp_test \
--speech-encoder-path-or-name models/zipformer/exp/epoch-999.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--deepspeed_config ./zipformer_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
--use-lora True \
--unfreeze-llm True
"""
import argparse
@ -53,7 +42,6 @@ import deepspeed
import torch
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from lhotse.cut import Cut
@ -61,10 +49,12 @@ from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zipformer import Zipformer2
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
@ -90,14 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
default="zipformer",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=8,
default=4,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
@ -121,6 +111,185 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to unfreeze llm during training.",
)
# Zipformer
parser.add_argument(
"--num-encoder-layers",
type=str,
default="2,2,3,4,3,2",
help="Number of zipformer encoder layers per stack, comma separated.",
)
parser.add_argument(
"--downsampling-factor",
type=str,
default="1,2,4,8,4,2",
help="Downsampling factor for each stack of encoder layers.",
)
parser.add_argument(
"--feedforward-dim",
type=str,
default="512,768,1024,1536,1024,768",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
)
parser.add_argument(
"--num-heads",
type=str,
default="4,4,4,8,4,4",
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
)
parser.add_argument(
"--encoder-dim",
type=str,
default="192,256,384,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--query-head-dim",
type=str,
default="32",
help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--value-head-dim",
type=str,
default="12",
help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--pos-head-dim",
type=str,
default="4",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
)
parser.add_argument(
"--pos-dim",
type=int,
default="48",
help="Positional-encoding embedding dimension",
)
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
default="192,192,256,256,256,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
)
parser.add_argument(
"--cnn-module-kernel",
type=str,
default="31,31,15,15,15,31",
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--causal",
type=str2bool,
default=False,
help="If True, use causal version of model.",
)
parser.add_argument(
"--chunk-size",
type=str,
default="16,32,64,-1",
help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
" Must be just -1 if --causal=False",
)
parser.add_argument(
"--left-context-frames",
type=str,
default="64,128,256,-1",
help="Maximum left-contexts for causal training, measured in frames which will "
"be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant.",
)
def load_model_params(ckpt: str, model: nn.Module, module: str, strict: bool = True):
"""Load model params from checkpoint
Args:
ckpt (str): Path to the checkpoint
model (nn.Module): model to be loaded
module (str): Module to be initialized
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
src_state_dict = checkpoint["model"]
dst_state_dict = model.state_dict()
logging.info(f"Loading parameters starting with prefix {module}")
module_prefix = module.strip() + "."
src_keys = [
k[len(module_prefix) :]
for k in src_state_dict.keys()
if k.startswith(module_prefix)
]
dst_keys = [k for k in dst_state_dict.keys()]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(module_prefix + key)
model.load_state_dict(dst_state_dict, strict=strict)
return None
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
def get_encoder_embed(params: AttributeDict) -> nn.Module:
# encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, (T - 7) // 2, encoder_dims).
# That is, it does two things simultaneously:
# (1) subsampling: T -> (T - 7) // 2
# (2) embedding: num_features -> encoder_dims
# In the normal configuration, we will downsample once more at the end
# by a factor of 2, and most of the encoder stacks will run at a lower
# sampling rate.
encoder_embed = Conv2dSubsampling(
in_channels=params.feature_dim,
out_channels=_to_int_tuple(params.encoder_dim)[0],
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
)
return encoder_embed
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Zipformer2(
output_downsampling_factor=2,
downsampling_factor=_to_int_tuple(params.downsampling_factor),
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(params.query_head_dim),
pos_head_dim=_to_int_tuple(params.pos_head_dim),
value_head_dim=_to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
)
return encoder
def get_parser():
parser = argparse.ArgumentParser(
@ -154,7 +323,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="whisper_qwen/exp",
default="zipformer_llm_zh/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -166,7 +335,7 @@ def get_parser():
type=str,
default=None,
help="""The path to the pretrained model if it is not None. Training will
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
start from this model. e.g. ./wenetspeech/ASR/zipformer/exp/epoch-999.pt
""",
)
@ -231,7 +400,6 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"allowed_excess_duration_ratio": 0.1,
"subsampling_factor": 2,
"frame_shift_ms": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
@ -241,6 +409,9 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 5000,
# parameters for zipformer
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
"env_info": get_env_info(),
}
)
@ -327,14 +498,13 @@ def compute_loss(
return input_ids, attention_mask, target_ids
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"]
texts = batch["supervisions"]["text"]
messages = []
@ -353,7 +523,8 @@ def compute_loss(
with torch.set_grad_enabled(is_training):
model_outputs, acc = model(
fbank=feature,
fbank=feature.to(device),
fbank_lens=feature_lens.to(device),
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
@ -364,7 +535,6 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = supervisions["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
@ -378,7 +548,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer,
tokenizer: AutoTokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@ -586,10 +756,16 @@ def run(rank, world_size, args):
logging.info("About to create model")
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
speech_encoder_embed = get_encoder_embed(params)
speech_encoder = get_encoder_model(params)
load_model_params(
params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
)
load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
for name, param in speech_encoder_embed.named_parameters():
param.requires_grad = False
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
@ -646,6 +822,7 @@ def run(rank, world_size, args):
)
model = SPEECH_LLM(
speech_encoder_embed,
speech_encoder,
llm,
encoder_projector,
@ -790,7 +967,6 @@ def display_and_save_batch(
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py