Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00
support zipformer encoder
update update update update fix reformat support infer update
This commit is contained in:
parent 211c01bc1d
commit 489c42b45e
@@ -194,10 +194,10 @@ class SPEECH_LLM(nn.Module):
     def forward(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        labels: torch.LongTensor = None,
+        fbank: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
+        labels: torch.LongTensor,
     ):
         encoder_outs = self.encoder(fbank)

@@ -3,6 +3,7 @@
 #                  Fangjun Kuang,
 #                  Wei Kang)
 #           2024  Yuekai Zhang
+#           2025  Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -19,31 +20,17 @@
 # limitations under the License.
 """
 Usage:
 # Command for decoding using fine-tuned models:

 pip install huggingface_hub['cli']
 mkdir -p models/whisper models/qwen models/checkpoint
 huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
 huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

-mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
-ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
-
-python3 ./whisper_llm_zh/decode.py \
+python3 ./zipformer_llm_zh/decode.py \
   --max-duration 80 \
-  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --exp-dir zipformer_llm_zh/exp \
+  --speech-encoder-path-or-name models/zipformer/epoch-999.pt \
   --llm-path-or-name models/qwen \
-  --epoch 999 --avg 1 \
+  --epoch 999 \
+  --avg 1 \
   --manifest-dir data/fbank \
   --use-flash-attn True \
-  --use-lora True --dataset aishell
+  --use-lora True \
+  --dataset aishell
 """

 import argparse
@@ -56,15 +43,22 @@ import k2
 import torch
 import torch.nn as nn
 import transformers
-import whisper
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-from train import DEFAULT_SPEECH_TOKEN
+from train import (
+    DEFAULT_SPEECH_TOKEN,
+    _to_int_tuple,
+    add_model_arguments,
+    get_encoder_embed,
+    get_encoder_model,
+    get_params,
+    load_model_params,
+)
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zipformer import Zipformer2

 from icefall.checkpoint import load_checkpoint
 from icefall.env import get_env_info
@@ -129,43 +123,6 @@ def average_checkpoints(
     return avg


-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
-        help="Path or name of the large language model.",
-    )
-
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -207,17 +164,10 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="whisper/exp",
+        default="zipformer/exp",
         help="The experiment dir",
     )

-    parser.add_argument(
-        "--remove-whisper-encoder-input-length-restriction",
-        type=str2bool,
-        default=True,
-        help="replace whisper encoder forward method to remove input length restriction",
-    )
-
     parser.add_argument(
         "--dataset",
         type=str,
@@ -230,15 +180,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -299,28 +240,13 @@ def decode_one_batch(

         return input_ids, attention_mask

-    dtype = torch.float32
     device = model.llm.device

     feature = batch["inputs"]
     assert feature.ndim == 3
-    feature = feature.to(device, dtype=dtype).transpose(1, 2)
-    if not params.remove_whisper_encoder_input_length_restriction:
-        T = 3000
-        if feature.shape[2] < T:
-            feature = torch.cat(
-                [
-                    feature,
-                    torch.zeros(
-                        feature.shape[0], feature.shape[1], T - feature.shape[2]
-                    ).to(device, dtype=dtype),
-                ],
-                2,
-            )

     supervisions = batch["supervisions"]
-    feature_len = supervisions["num_frames"]
-    feature_len = feature_len.to(device, dtype=dtype)
+    feature_lens = supervisions["num_frames"]

     messages = [
         [
@@ -332,7 +258,10 @@ def decode_one_batch(
     input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

     generated_ids = model.decode(
-        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        feature.to(device),
+        feature_lens.to(device),
+        input_ids.to(device, dtype=torch.long),
+        attention_mask.to(device),
     )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

@@ -471,12 +400,15 @@ def main():

     logging.info(f"device: {device}")

-    if params.remove_whisper_encoder_input_length_restriction:
-        replace_whisper_encoder_forward()
+    speech_encoder_embed = get_encoder_embed(params)
+    speech_encoder = get_encoder_model(params)
+    load_model_params(
+        params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
+    )
+    load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
+
+    speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))

-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

     if params.use_flash_attn:
@@ -528,6 +460,7 @@ def main():
     )

     model = SPEECH_LLM(
+        speech_encoder_embed,
         speech_encoder,
         llm,
         encoder_projector,

egs/speech_llm/ASR_LLM/zipformer_llm_zh/encoder_interface.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
@@ -1,7 +1,12 @@
+from typing import Tuple
+
 import torch
+from encoder_interface import EncoderInterface
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother

+from icefall.utils import make_pad_mask
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index

@@ -55,11 +60,13 @@ class SPEECH_LLM(nn.Module):

     def __init__(
         self,
-        encoder: nn.Module,
+        encoder_embed: nn.Module,
+        encoder: EncoderInterface,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
         super().__init__()
+        self.encoder_embed = encoder_embed
         self.encoder = encoder
         self.llm = llm
         self.encoder_projector = encoder_projector
@@ -192,14 +199,46 @@ class SPEECH_LLM(nn.Module):

         return final_embedding, final_attention_mask, final_labels, position_ids

+    def forward_encoder(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute encoder outputs.
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+
+        Returns:
+          encoder_out:
+            Encoder output, of shape (N, T, C).
+          encoder_out_lens:
+            Encoder output lengths, of shape (N,).
+        """
+        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+        x, x_lens = self.encoder_embed(x, x_lens)
+        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
+
+        return encoder_out, encoder_out_lens
+
     def forward(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        labels: torch.LongTensor = None,
+        fbank: torch.Tensor,
+        fbank_lens: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
+        labels: torch.LongTensor,
     ):
-        encoder_outs = self.encoder(fbank)
+        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)

         speech_features = self.encoder_projector(encoder_outs)

@@ -229,15 +268,17 @@ class SPEECH_LLM(nn.Module):

     def decode(
         self,
-        fbank: torch.Tensor = None,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
+        fbank: torch.Tensor,
+        fbank_lens: torch.Tensor,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.Tensor,
         **kwargs,
     ):
+        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)

-        encoder_outs = self.encoder(fbank)
         speech_features = self.encoder_projector(encoder_outs)
         speech_features = speech_features.to(torch.float16)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,

egs/speech_llm/ASR_LLM/zipformer_llm_zh/scaling.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py

egs/speech_llm/ASR_LLM/zipformer_llm_zh/subsampling.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
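Before the train.py changes below, here is a minimal, self-contained shape walkthrough of the forward_encoder path added to model.py above. The local make_pad_mask and the strided stand-ins for encoder_embed and Zipformer2 are assumptions made only to keep the sketch runnable on its own; they are not the recipe's actual modules.

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True at padded positions, shape (N, T); mirrors the behaviour expected
    # from icefall.utils.make_pad_mask in this sketch.
    max_len = int(lengths.max())
    return torch.arange(max_len)[None, :] >= lengths[:, None]

N, T, C = 2, 10, 80                      # batch, frames, fbank bins
x = torch.randn(N, T, C)                 # padded fbank batch
x_lens = torch.tensor([10, 7])           # frames before padding

# encoder_embed subsamples in time and maps C -> encoder_dim; faked here by
# simple striding so the example stays self-contained.
x, x_lens = x[:, ::2, :], x_lens // 2

src_key_padding_mask = make_pad_mask(x_lens)    # (N, T')
x = x.permute(1, 0, 2)                          # (N, T', C) -> (T', N, C)

# Stand-in for the Zipformer2 call, which returns (encoder_out, encoder_out_lens).
encoder_out, encoder_out_lens = x, x_lens

encoder_out = encoder_out.permute(1, 0, 2)      # back to (N, T', C)
print(encoder_out.shape, encoder_out_lens)      # torch.Size([2, 5, 80]) tensor([5, 3])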
@@ -18,28 +18,17 @@
 # limitations under the License.
 """
 Usage:
 # fine-tuning with whisper and Qwen2
 pip install huggingface_hub['cli']
 mkdir -p models/whisper models/qwen

-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
 # huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
 huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+torchrun --nproc_per_node 8 ./zipformer_llm_zh/train.py \
   --max-duration 200 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --exp-dir ./zipformer_llm_zh/exp_test \
+  --speech-encoder-path-or-name models/zipformer/exp/epoch-999.pt \
   --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
   --manifest-dir data/fbank \
   --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --deepspeed_config ./zipformer_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
+  --use-lora True \
+  --unfreeze-llm True
 """

 import argparse
@@ -53,7 +42,6 @@ import deepspeed
 import torch
 import torch.nn as nn
 import transformers
-import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from lhotse.cut import Cut
@@ -61,10 +49,12 @@ from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zipformer import Zipformer2

 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
@@ -90,14 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--speech-encoder-path-or-name",
         type=str,
-        default="whisper-large-v2",
+        default="zipformer",
         help="Path or name of the speech encoder.",
     )

     parser.add_argument(
         "--encoder-projector-ds-rate",
         type=int,
-        default=8,
+        default=4,
         help="Downsample rate for the encoder projector.",
     )
     parser.add_argument(
@@ -121,6 +111,185 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to unfreeze llm during training.",
     )

+    # Zipformer
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,2,3,4,3,2",
+        help="Number of zipformer encoder layers per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--downsampling-factor",
+        type=str,
+        default="1,2,4,8,4,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dim",
+        type=str,
+        default="512,768,1024,1536,1024,768",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="4,4,4,8,4,4",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="192,256,384,512,384,256",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default="48",
+        help="Positional-encoding embedding dimension",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="192,192,256,256,256,192",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31,31,15,15,15,31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--causal",
+        type=str2bool,
+        default=False,
+        help="If True, use causal version of model.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="16,32,64,-1",
+        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+        "Must be just -1 if --causal=False",
+    )
+
+    parser.add_argument(
+        "--left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Maximum left-contexts for causal training, measured in frames which will "
+        "be converted to a number of chunks. If splitting into chunks, "
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+    )
+
+
+def load_model_params(ckpt: str, model: nn.Module, module: str, strict: bool = True):
+    """Load model params from checkpoint
+
+    Args:
+        ckpt (str): Path to the checkpoint
+        model (nn.Module): model to be loaded
+        module (str): Module to be initialized
+
+    """
+    logging.info(f"Loading checkpoint from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu")
+
+    src_state_dict = checkpoint["model"]
+    dst_state_dict = model.state_dict()
+    logging.info(f"Loading parameters starting with prefix {module}")
+    module_prefix = module.strip() + "."
+    src_keys = [
+        k[len(module_prefix) :]
+        for k in src_state_dict.keys()
+        if k.startswith(module_prefix)
+    ]
+    dst_keys = [k for k in dst_state_dict.keys()]
+    assert set(src_keys) == set(dst_keys)  # two sets should match exactly
+    for key in src_keys:
+        dst_state_dict[key] = src_state_dict.pop(module_prefix + key)
+
+    model.load_state_dict(dst_state_dict, strict=strict)
+
+    return None
+
+
+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    # encoder_embed converts the input of shape (N, T, num_features)
+    # to the shape (N, (T - 7) // 2, encoder_dims).
+    # That is, it does two things simultaneously:
+    #   (1) subsampling: T -> (T - 7) // 2
+    #   (2) embedding: num_features -> encoder_dims
+    # In the normal configuration, we will downsample once more at the end
+    # by a factor of 2, and most of the encoder stacks will run at a lower
+    # sampling rate.
+    encoder_embed = Conv2dSubsampling(
+        in_channels=params.feature_dim,
+        out_channels=_to_int_tuple(params.encoder_dim)[0],
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+    )
+    return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    encoder = Zipformer2(
+        output_downsampling_factor=2,
+        downsampling_factor=_to_int_tuple(params.downsampling_factor),
+        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+        encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=_to_int_tuple(params.query_head_dim),
+        pos_head_dim=_to_int_tuple(params.pos_head_dim),
+        value_head_dim=_to_int_tuple(params.value_head_dim),
+        pos_dim=params.pos_dim,
+        num_heads=_to_int_tuple(params.num_heads),
+        feedforward_dim=_to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
+        causal=params.causal,
+        chunk_size=_to_int_tuple(params.chunk_size),
+        left_context_frames=_to_int_tuple(params.left_context_frames),
+    )
+    return encoder
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
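As a side note on the helpers introduced in this hunk, here is a small hedged sketch of how the comma-separated Zipformer options are parsed by _to_int_tuple and how load_model_params keeps only the entries under one module prefix of a pre-trained checkpoint; the key names in the toy state dict are hypothetical, chosen only to illustrate the remapping.

import torch

def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

# Per-stack hyper-parameters arrive as comma-separated strings:
assert _to_int_tuple("2,2,3,4,3,2") == (2, 2, 3, 4, 3, 2)
assert max(_to_int_tuple("192,256,384,512,384,256")) == 512  # used as speech_encoder_dim

# load_model_params keeps only the checkpoint keys under one module prefix and
# strips that prefix before load_state_dict; a toy state dict shows the idea:
src_state_dict = {
    "encoder.layers.0.weight": torch.zeros(1),        # hypothetical key names
    "encoder_embed.conv.0.weight": torch.zeros(1),
    "joiner.output.weight": torch.zeros(1),
}
module_prefix = "encoder."
encoder_keys = {
    k[len(module_prefix):]: v
    for k, v in src_state_dict.items()
    if k.startswith(module_prefix)
}
print(sorted(encoder_keys))  # ['layers.0.weight'] -- only the encoder params remain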
@@ -154,7 +323,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="whisper_qwen/exp",
+        default="zipformer_llm_zh/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -166,7 +335,7 @@ def get_parser():
         type=str,
         default=None,
         help="""The path to the pretrained model if it is not None. Training will
-        start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
+        start from this model. e.g. ./wenetspeech/ASR/zipformer/exp/epoch-999.pt
         """,
     )

@@ -231,7 +400,6 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "allowed_excess_duration_ratio": 0.1,
-            "subsampling_factor": 2,
             "frame_shift_ms": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -241,6 +409,9 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 5000,
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
             "env_info": get_env_info(),
         }
     )
@@ -327,14 +498,13 @@ def compute_loss(
         return input_ids, attention_mask, target_ids

     device = next(model.parameters()).device

     feature = batch["inputs"]
     assert feature.ndim == 3
-    feature = feature.to(device)
-    feature = feature.transpose(1, 2)  # (N, C, T)

     batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"]

     texts = batch["supervisions"]["text"]

     messages = []
@@ -353,7 +523,8 @@ def compute_loss(

     with torch.set_grad_enabled(is_training):
         model_outputs, acc = model(
-            fbank=feature,
+            fbank=feature.to(device),
+            fbank_lens=feature_lens.to(device),
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
             labels=target_ids.to(device),
@@ -364,7 +535,6 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        feature_lens = supervisions["num_frames"]
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
@@ -378,7 +548,7 @@ def compute_loss(

 def compute_validation_loss(
     params: AttributeDict,
-    tokenizer: whisper.tokenizer.Tokenizer,
+    tokenizer: AutoTokenizer,
     model: nn.Module,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -586,10 +756,16 @@ def run(rank, world_size, args):

     logging.info("About to create model")

-    replace_whisper_encoder_forward()
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
+    speech_encoder_embed = get_encoder_embed(params)
+    speech_encoder = get_encoder_model(params)
+    load_model_params(
+        params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
+    )
+    load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
+
+    speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
+    for name, param in speech_encoder_embed.named_parameters():
+        param.requires_grad = False
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False

@@ -646,6 +822,7 @@ def run(rank, world_size, args):
     )

     model = SPEECH_LLM(
+        speech_encoder_embed,
         speech_encoder,
         llm,
         encoder_projector,
@@ -790,7 +967,6 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)

-    supervisions = batch["supervisions"]
     features = batch["inputs"]

     logging.info(f"features shape: {features.shape}")

egs/speech_llm/ASR_LLM/zipformer_llm_zh/zipformer.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py