いろいろ倉庫

KNIME、EXCEL、R、Pythonなどの備忘録

【Python】dtreevizで決定木を可視化したい

・お題:決定木を、dtreevizというライブラリで、分かりやすく可視化したい。

 

・決定木は判断根拠が分かりやすいアルゴリズムで、最終的な分類結果だけではなく、その判断プロセスを確認できる点で重宝されることが多い。

・例えばIrisデータセットを決定木で分類するアルゴリズムをsklearnで図示すると、以下のようになる。

・これでも分かるのだけれど、もっと見やすく図示をできるライブラリがあるので、メモしておく。

・ライブラリはdtreevizというやつなのだけれど、Graphvizというソフトがないとエラーが出て怒られるので、とりあえずGraphvizをインストールしておく。詳細は公式サイトを熟読いただきたい。

・以下の公式サイトからインストーラをダウンロードし、インストールを進める。

www.graphviz.org

・最後らへんで、Install Optionsとして、PATHを通すかどうか聞かれる。先にdtreevizを試した際に、PATHを通すように怒られたので、通しておく(下図)。

・なんとなくPCを再起動し、先と同様のグラフをdtreevizで描いてみる。

・データをirisにロードする。

from sklearn.datasets import load_iris
iris=load_iris()

・決定木により分類器を作成する。

from sklearn import tree
clf=tree.DecisionTreeClassifier(max_depth=2)
clf.fit(iris.data,iris.target)

・ちなみに、最初の図は以下のコマンドで出てくる。

from sklearn.tree import plot_tree
plot_tree(clf,
          feature_names=iris.feature_names,
          class_names=[str(i) for i in iris.target_names],
          filled=True)

・次に、dtreevizで描いてみる。

from dtreeviz.trees import dtreeviz

viz = dtreeviz(clf,
               iris.data,
               iris.target,
               target_name="variety",
               feature_names=iris.feature_names,
               class_names=[str(i) for i in iris.target_names])

・以下のコマンドで、.svgファイルを作成して開くことで表示される。私の場合はデフォルトでブラウザから.svgファイルが展開された。このファイルは必要に応じて保存すれば良い。

viz.view() 

・次に、特定のデータに関して、その分類に影響した情報を図示する。
viz2 = dtreeviz(clf,
               iris.data,
               iris.target,
               target_name="variety",
               feature_names=iris.feature_names,
               class_names=[str(i) for i in iris.target_names],
               X=iris.data[100])

viz2.view() 

・オレンジの変数が分類に大事だったっぽい。
・問題は解釈する能力なんだなぁ、と思った。

 

おわり。