いろいろ倉庫

KNIME、EXCEL、R、Pythonなどの備忘録

【Python】相関ネットワークの図を描きたい②

・お題:いろいろな変数の数値データが入った表を入手した。変数の相関関係からグラフを描きたい。

 

・やりたいことは先日と同じだけれど、今回はnetworkxの描画をもう少し工夫してみた。以下のサイトを参考にさせて頂いた。

qiita.com

・とりあえず、データセットを作成する。7個の乱数のセットを一つの変数と想定し、20個の変数を作った。

import numpy as np
np.random.seed(seed=0)
   
Dic1=dict*1 for n in range(20))

import pandas as pd
df=pd.DataFrame(Dic1)

・次に、相関係数を求めた。

df_corr=df.corr()

・次に、2つの変数のペアと相関係数のタプルをエッジリストとして作成。

edges=[]
for n in range(len(df.columns)):
    for m in range(n+1,len(df.columns)):
        edges.append*2

・これでedgesは以下のようになった。

[('val_0', 'val_1', 0.057261314415455164),
 ('val_0', 'val_2', -0.5411289176512385),
 ('val_0', 'val_3', -0.4616430051289798),
 ('val_0', 'val_4', 0.28284782289229016),
 ('val_0', 'val_5', -0.12345277973531268),
 ('val_0', 'val_6', -0.2027262193814838),

......

相関係数の絶対値が小さいものは無視したいので、エッジリストから除外することにする。今回は、相関係数の絶対値が0.6を超えるエッジだけとって来ることにする。

Lis1=[ed for ed in edges if abs(ed[-1])>0.6]

・これでLis1は以下になった。

[('val_0', 'val_11', -0.6511401663173221),
 ('val_0', 'val_16', -0.7403557671505099),
 ('val_1', 'val_14', -0.6338529765771386),
 ('val_1', 'val_18', -0.8835625773960867),
......

相関係数の絶対値をweightとして利用するため、読み込み用リストのLis2を作成。

Lis2=[(n[0],n[1],abs(n[2])) for n in Lis1]

・これをネットワークとして読み込む。

import networkx as nx
G = nx.Graph()
G.add_weighted_edges_from(Lis2)

・これでグラフを作成する。

%matplotlib inline
import matplotlib.pyplot as plt

plt.figure(figsize=(7,7))

#posでレイアウトを指定。
pos = nx.spring_layout(G, k=0.4)

#ノードの大きさは、次数(ノードから伸びているエッジの数)を利用して設定

sizes=[100*(n**2) for n in list(dict(G.degree()).values())]
nx.draw_networkx_nodes(G, pos, node_color="white",alpha=0.7, node_size=sizes, edgecolors="gray")

#エッジの幅は相関係数の絶対値を利用して設定

widths=[(m*2)**2 for m in [n[-1] for n in Lis1]]

#エッジの色は相関係数を利用して設定。相関係数が正なら赤、負なら青、絶対値が大きいほど濃くなる。ついでに透過度も相関係数の絶対値で設定。

colors=[(m+1)/2 for m in [n[-1] for n in Lis1]]
alphas=[abs(m-0.5)*2 for m in colors]
nx.draw_networkx_edges(G, pos, alpha=alphas, width=widths, edge_color=colors, edge_cmap=plt.cm.bwr)

#ラベルを表示。

nx.draw_networkx_labels(G, pos, font_size=10, font_family="Yu Gothic")


plt.show()

・それっぽくなった。今回地味に苦労したのが、pos = nx.spring_layout(G, k=0.4)のkの設定。kの値によってグラフの形が変わる。
k=0.1

k=1

k=10

・可視化が目的なので、見てわかりやすいようにパラメータを調整する必要がありそう。

 

おわり。

 

 

 

*1:f"val_{n}",np.random.rand(7

*2:df_corr.index[n],df_corr.columns[m],df_corr.iloc[n,m]