mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
77 lines
2.5 KiB
Python
Executable File
77 lines
2.5 KiB
Python
Executable File
import os
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
|
|
out_dir='outputs/'
|
|
|
|
org = torch.load('pruned_transducer_stateless_d2v_v2/d2v-T/epoch-27.pt')
|
|
ft = torch.load('pruned_transducer_stateless_d2v_v2/d2v-T-LJft-oracle-epoch100/best-valid-loss.pt')
|
|
pft = torch.load('pruned_transducer_stateless_d2v_v2/d2v-T-LJft-pseudo-trained/best-valid-loss.pt')
|
|
|
|
org_state_dict, ft_state_dict, pft_state_dict = org['model'], ft['model'], pft['model']
|
|
|
|
# # Define three model parameters
|
|
# for (n1, p1), (n2, p2), (n3, p3) in zip(org.named_parameters(), ft.named_parameters(), pft.named_parameters()):
|
|
# assert n1 == n2 == n3
|
|
# org_state_dict[n1] = p1
|
|
# ft_state_dict[n1] = p2
|
|
# pft_state_dict[n1] = p3
|
|
|
|
for name in org_state_dict.keys():
|
|
# Define the weight names
|
|
weights = [org_state_dict[name], ft_state_dict[name], pft_state_dict[name]]
|
|
|
|
# Define the x-axis labels
|
|
x = ['org vs ft', 'org vs pft', 'ft vs pft']
|
|
abs_diff, rel_diff, cos_sim = [], [], []
|
|
|
|
i = 0
|
|
k = 0
|
|
for j in range(i+1,3):
|
|
if weights[i].dim() > 1:
|
|
wi = weights[i].view(-1)
|
|
wj = weights[j].view(-1)
|
|
else:
|
|
wi = weights[i]
|
|
wj = weights[j]
|
|
|
|
# Compute absolute difference
|
|
abs_diff.append(torch.abs(wi - wj).sum(-1).cpu().numpy())
|
|
|
|
# Compute relative difference
|
|
rel_diff.append(abs_diff[k] / torch.max(torch.abs(wi).sum(-1), torch.abs(wj).sum(-1)).cpu().numpy())
|
|
|
|
# Compute cosine similarity
|
|
cos_sim.append(torch.nn.functional.cosine_similarity(wi, wj, dim=0).cpu().numpy())
|
|
|
|
k += 1
|
|
|
|
i, j = 1, 2
|
|
if weights[i].dim() > 1:
|
|
wi = weights[i].view(-1)
|
|
wj = weights[j].view(-1)
|
|
else:
|
|
wi = weights[i]
|
|
wj = weights[j]
|
|
|
|
# Compute absolute difference
|
|
abs_diff.append(torch.abs(wi - wj).sum(-1).cpu().numpy())
|
|
|
|
# Compute relative difference
|
|
rel_diff.append(abs_diff[2] / torch.max(torch.abs(wi).sum(-1), torch.abs(wj).sum(-1)).cpu().numpy())
|
|
|
|
# Compute cosine similarity
|
|
cos_sim.append(torch.nn.functional.cosine_similarity(wi, wj, dim=0).cpu().numpy())
|
|
|
|
if abs_diff[1] > 0.001:
|
|
print(name)
|
|
|
|
for k, metric in enumerate([abs_diff, rel_diff, cos_sim]):
|
|
y = metric
|
|
# Plot results as a bar graph
|
|
plt.bar(x, y)
|
|
|
|
if not os.path.exists(out_dir + name):
|
|
os.makedirs(out_dir + name, exist_ok=True)
|
|
plt.savefig(out_dir + name + "/" + ['abs_diff', 'rel_diff', 'cos_sim'][k] + ".png")
|
|
plt.close() |