icefall/egs/LJSpeech/ASR/utils/plot_similarities.py
2023-01-19 11:27:19 +09:00

77 lines
2.5 KiB
Python
Executable File

import os
import torch
import matplotlib.pyplot as plt
# Root directory for the plots; one sub-directory is created per parameter.
out_dir = 'outputs/'

# Checkpoints to compare.  Judging by the paths these are the original model
# ("org"), an oracle fine-tuned model ("ft") and a pseudo-label fine-tuned
# model ("pft") -- confirm against the experiment setup.
ORG_CKPT = 'pruned_transducer_stateless_d2v_v2/d2v-T/epoch-27.pt'
FT_CKPT = 'pruned_transducer_stateless_d2v_v2/d2v-T-LJft-oracle-epoch100/best-valid-loss.pt'
PFT_CKPT = 'pruned_transducer_stateless_d2v_v2/d2v-T-LJft-pseudo-trained/best-valid-loss.pt'

# X-axis labels and the (i, j) indices into [org, ft, pft] they refer to.
PAIR_LABELS = ['org vs ft', 'org vs pft', 'ft vs pft']
PAIR_INDICES = [(0, 1), (0, 2), (1, 2)]

# File stems of the saved figures, in the order compare_weights() returns them.
METRIC_NAMES = ['abs_diff', 'rel_diff', 'cos_sim']

# A parameter name is printed when its "org vs pft" absolute difference
# exceeds this value.
PRINT_THRESHOLD = 0.001


def load_state_dict(path):
    """Load the checkpoint at *path* onto the CPU and return its 'model' entry.

    ``map_location='cpu'`` lets checkpoints that were saved from a GPU be
    inspected on a machine without one.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location='cpu')
    return checkpoint['model']


def flatten_weight(weight):
    """Return *weight* as a 1-D floating-point tensor.

    ``reshape`` (unlike ``view``) also accepts non-contiguous tensors and
    turns 0-dim tensors into one-element vectors, so ``cosine_similarity``
    along ``dim=0`` is always well defined.  Non-float64 tensors (integer
    buffers, half precision) are converted to float32 because the norm used
    by ``cosine_similarity`` needs a floating-point dtype.
    """
    flat = weight.detach().reshape(-1)
    if flat.dtype != torch.float64:
        flat = flat.float()
    return flat


def compare_weights(weights):
    """Compare three same-shaped tensors pairwise.

    Args:
        weights: sequence ``[org, ft, pft]`` of tensors with equal shapes.

    Returns:
        A tuple ``(abs_diff, rel_diff, cos_sim)`` of three lists of Python
        floats, each ordered like ``PAIR_LABELS``:

        * ``abs_diff``: sum of element-wise absolute differences,
        * ``rel_diff``: ``abs_diff`` divided by the larger of the two
          tensors' absolute sums (0.0 when both sums are zero, instead of
          the NaN a plain division would give),
        * ``cos_sim``: cosine similarity of the flattened tensors.
    """
    flats = [flatten_weight(w) for w in weights]
    abs_diff, rel_diff, cos_sim = [], [], []
    for i, j in PAIR_INDICES:
        wi, wj = flats[i], flats[j]
        diff = torch.abs(wi - wj).sum().item()
        scale = torch.max(torch.abs(wi).sum(), torch.abs(wj).sum()).item()
        abs_diff.append(diff)
        rel_diff.append(0.0 if scale == 0 else diff / scale)
        cos_sim.append(
            torch.nn.functional.cosine_similarity(wi, wj, dim=0).item()
        )
    return abs_diff, rel_diff, cos_sim


def save_bar_plots(name, metrics):
    """Save one bar chart per metric under ``out_dir/<name>/<metric>.png``.

    Args:
        name: parameter name, used as the sub-directory name.
        metrics: the ``(abs_diff, rel_diff, cos_sim)`` tuple returned by
            ``compare_weights``.
    """
    target_dir = os.path.join(out_dir, name)
    os.makedirs(target_dir, exist_ok=True)
    for metric_name, values in zip(METRIC_NAMES, metrics):
        plt.bar(PAIR_LABELS, values)
        plt.savefig(os.path.join(target_dir, metric_name + '.png'))
        # Close the figure so bars from one metric do not leak into the next.
        plt.close()


def main():
    """Compare every parameter of the original checkpoint with both others."""
    org_state_dict = load_state_dict(ORG_CKPT)
    ft_state_dict = load_state_dict(FT_CKPT)
    pft_state_dict = load_state_dict(PFT_CKPT)

    for name, org_weight in org_state_dict.items():
        if name not in ft_state_dict or name not in pft_state_dict:
            print(f'[skip] {name}: not present in all three checkpoints')
            continue
        weights = [org_weight, ft_state_dict[name], pft_state_dict[name]]
        if not all(isinstance(w, torch.Tensor) for w in weights):
            print(f'[skip] {name}: entry is not a tensor')
            continue
        if len({tuple(w.shape) for w in weights}) != 1:
            print(f'[skip] {name}: shapes differ between checkpoints')
            continue

        metrics = compare_weights(weights)
        # metrics[0][1] is the absolute difference for 'org vs pft'.
        if metrics[0][1] > PRINT_THRESHOLD:
            print(name)
        save_bar_plots(name, metrics)


if __name__ == '__main__':
    main()