import torch
from sentence_transformers import SentenceTransformer

model_id = "google/embeddinggemma-300m"
model = SentenceTransformer(model_id)

# Quick sanity check: print the first two named parameters of the base model.
print("Original model")
k = 0
for name, param in model.named_parameters():
    print(name)
    print(param)
    k += 1
    if k > 1:
        break
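# Optional sketch (not in the original script): total parameter count of the
# base model, which should come out near 300M for embeddinggemma-300m.
n_params = sum(p.numel() for p in model.parameters())
print(f"base model parameters: {n_params:,}")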
model_id = "./models/gemma/checkpoint-33246"
model_lora = SentenceTransformer(model_id)
print("LoRA model")
k = 0
for name, param in model_lora.named_parameters():
print(name)
print(param)
k += 1
if k == 3:
a = param
if k == 4:
b = param
if k > 3:
delta = (b @ a) * 2.0
print(delta)
break
print(k)
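# A less fragile alternative to the positional k == 3 / k == 4 capture above
# (a sketch, assuming the PEFT naming convention "...lora_A.default.weight";
# adjust if the checkpoint uses an adapter name other than "default"):
lora_pairs = {}
for name, param in model_lora.named_parameters():
    if ".lora_A." in name:
        lora_pairs.setdefault(name.split(".lora_A.")[0], {})["A"] = param
    elif ".lora_B." in name:
        lora_pairs.setdefault(name.split(".lora_B.")[0], {})["B"] = param
for layer, pair in lora_pairs.items():
    if "A" in pair and "B" in pair:
        print(layer, tuple((pair["B"] @ pair["A"]).shape))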
def compare_lora_to_base(model_lora, model_base, lora_scale=1.0):
    """
    Compare how much each weight matrix has changed between
    the base model and the LoRA-adapted model.
    """
    report = []
    total_change = 0.0
    total_params = 0
    has_lora = []
    no_lora = []
    for name, module in model_lora.named_modules():
        # LoRA-adapted modules typically carry lora_A and lora_B sub-modules.
        if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
            A = module.lora_A["default"].weight.data  # (r, in_features)
            B = module.lora_B["default"].weight.data  # (out_features, r)
            delta = (B @ A) * lora_scale
            # Find the matching base layer; skip modules whose path does not
            # resolve in the base model (e.g. PEFT wrapper prefixes).
            try:
                base_weight = model_base.get_submodule(name).weight.data
                has_lora.append(name)
            except Exception:
                no_lora.append(name)
                continue
            new_weight = base_weight + delta
            diff = (new_weight - base_weight).abs()
            # Per-element relative change, guarded against division by zero.
            relative_change = diff / (base_weight.abs() + 1e-8)
            mean_change = relative_change.mean().item() * 100
            report.append((name, mean_change))
            total_change += relative_change.sum().item()
            total_params += relative_change.numel()
        else:
            no_lora.append(name)
    print("has_lora", has_lora)
    print("no_lora", no_lora)
    print("lora num", len(has_lora))
    print("no lora num", len(no_lora))
    overall_change = (total_change / total_params) * 100 if total_params > 0 else 0.0
    return report, overall_change
report, overall_change = compare_lora_to_base(model_lora, model, lora_scale=2.0)
print(f"overall_change: {overall_change:.4f}%")