"""Inspect a base EmbeddingGemma model and a LoRA fine-tuned checkpoint.

Loads both models with sentence_transformers, prints a few of their
parameters, and measures how much the LoRA adapter changes the base weights.
"""
import itertools

import torch
from sentence_transformers import SentenceTransformer

# Base (pre-fine-tuning) model, loaded by Hub id.
base_model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(base_model_id)

print("Original model")
# Show only the first two parameters; dumping every tensor would be huge.
for name, param in itertools.islice(model.named_parameters(), 2):
    print(name)
    print(param)
# Fine-tuned checkpoint; presumably carries a LoRA adapter (the code below
# looks for lora_A / lora_B modules) — TODO confirm against the trainer config.
model_id = "./models/gemma/checkpoint-33246"
model_lora = SentenceTransformer(model_id)

print("LoRA model")
# Locate the first LoRA-wrapped module by its attributes rather than assuming
# the 3rd and 4th parameters happen to be its lora_A / lora_B weights.
lora_name, lora_module = next(
    (
        (mod_name, mod)
        for mod_name, mod in model_lora.named_modules()
        if hasattr(mod, "lora_A") and hasattr(mod, "lora_B")
    ),
    (None, None),
)

if lora_module is None:
    print("No LoRA modules found")
else:
    a = lora_module.lora_A["default"].weight
    b = lora_module.lora_B["default"].weight
    print(f"{lora_name}.lora_A.default.weight")
    print(a)
    print(f"{lora_name}.lora_B.default.weight")
    print(b)

    # 2.0 is the LoRA scaling factor, presumably lora_alpha / r for this
    # checkpoint — TODO confirm against adapter_config.json.
    # no_grad: these are trainable parameters; no autograd graph is needed.
    with torch.no_grad():
        delta = (b @ a) * 2.0
    print(delta)
def compare_lora_to_base(model_lora, model_base, lora_scale=1.0, adapter_name="default"):
    """Measure how much a LoRA adapter changes the weights it is attached to.

    For every module of ``model_lora`` that has both ``lora_A`` and ``lora_B``
    attributes, the low-rank update ``delta = (B @ A) * scale`` is compared
    element-wise against the ``.weight`` of the same-named module in
    ``model_base``.

    Args:
        model_lora: Model containing LoRA-wrapped layers. ``lora_A`` and
            ``lora_B`` are indexed by ``adapter_name`` and must expose a
            ``.weight`` (PEFT-style layout; assumed, not checked).
        model_base: Model without the adapter; modules are looked up by the
            same dotted names via ``get_submodule``.
        lora_scale: Multiplier applied to ``B @ A``. Pass ``None`` to use the
            layer's own ``scaling[adapter_name]`` attribute instead (PEFT LoRA
            layers store their alpha/r scale there; verify for your version).
        adapter_name: Adapter key inside ``lora_A`` / ``lora_B``.

    Returns:
        ``(report, overall_change)``:
        - ``report`` is a list of ``(module_name, mean_relative_change_percent)``
          tuples, one per compared module.
        - ``overall_change`` is the mean element-wise relative change over all
          compared elements, in percent (``0.0`` if nothing was compared).

    Also prints the compared modules (``has_lora``) and the rest
    (``no_lora``: modules without LoRA factors, or whose base counterpart
    could not be found), with their counts.
    """
    report = []
    total_change = 0.0
    total_params = 0
    has_lora = []
    no_lora = []

    for name, module in model_lora.named_modules():
        if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")):
            no_lora.append(name)
            continue

        try:
            base_weight = model_base.get_submodule(name).weight.data
        except AttributeError:
            # get_submodule raises AttributeError for a missing path, as does
            # a module without .weight. Skip this module: falling through
            # would compare against a stale base_weight from an earlier
            # iteration (or raise NameError on the first LoRA module).
            no_lora.append(name)
            continue
        has_lora.append(name)

        scale = module.scaling[adapter_name] if lora_scale is None else lora_scale
        A = module.lora_A[adapter_name].weight.data
        B = module.lora_B[adapter_name].weight.data

        # Work in float32 on the base weight's device so half-precision
        # weights don't lose precision in the matmul and dtype/device
        # mismatches with the base weight cannot occur.
        base = base_weight.float()
        delta = (B.float() @ A.float()).to(base.device) * scale

        # The merged weight is base + delta, so |merged - base| == |delta|;
        # use delta directly rather than adding and subtracting base again.
        diff = delta.abs()
        # 1e-8 avoids division by zero; near-zero base weights can still
        # dominate this element-wise mean.
        relative_change = diff / (base.abs() + 1e-8)
        mean_change = relative_change.mean().item() * 100

        report.append((name, mean_change))
        total_change += relative_change.sum().item()
        total_params += relative_change.numel()

    print("has_lora", has_lora)
    print("no_lora", no_lora)
    print("lora num", len(has_lora))
    print("no lora num", len(no_lora))

    overall_change = (total_change / total_params) * 100 if total_params > 0 else 0.0
    return report, overall_change
# lora_scale=2.0 is presumably lora_alpha / r for this checkpoint — TODO
# confirm against the adapter's config.
report, overall_change = compare_lora_to_base(model_lora, model, lora_scale=2.0)

# Per-module breakdown, most-changed first (the report was previously
# computed but never shown).
for module_name, mean_change in sorted(report, key=lambda item: item[1], reverse=True):
    print(f"{module_name}: {mean_change:.4f}%")

print(f"overall_change: {overall_change}")