・お題:決定木を、dtreevizというライブラリで、分かりやすく可視化したい。
・決定木は判断根拠が分かりやすいアルゴリズムで、最終的な分類結果だけではなく、その判断プロセスを確認できる点で重宝されることが多い。
・例えばIrisデータセットを決定木で分類するアルゴリズムをsklearnで図示すると、以下のようになる。
・これでも分かるのだけれど、もっと見やすく図示をできるライブラリがあるので、メモしておく。
・ライブラリはdtreevizというやつなのだけれど、Graphvizというソフトがないとエラーが出て怒られるので、とりあえずGraphvizをインストールしておく。詳細は公式サイトを熟読いただきたい。
・以下の公式サイトからインストーラをダウンロードし、インストールを進める。
・最後らへんで、Install Optionsとして、PATHを通すかどうか聞かれる。先にdtreevizを試した際に、PATHを通すように怒られたので、通しておく(下図)。
・なんとなくPCを再起動し、先と同様のグラフをdtreevizで描いてみる。
・データをirisにロードする。
from sklearn.datasets import load_iris
iris=load_iris()
・決定木により分類器を作成する。
from sklearn import tree
clf=tree.DecisionTreeClassifier(max_depth=2)
clf.fit(iris.data,iris.target)
・ちなみに、最初の図は以下のコマンドで出てくる。
from sklearn.tree import plot_tree
plot_tree(clf,
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
filled=True)
・次に、dtreevizで描いてみる。
from dtreeviz.trees import dtreeviz
viz = dtreeviz(clf,
iris.data,
iris.target,
target_name="variety",
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names])
・以下のコマンドで、.svgファイルを作成して開くことで表示される。私の場合はデフォルトでブラウザから.svgファイルが展開された。このファイルは必要に応じて保存すれば良い。
viz.view()
・次に、特定のデータに関して、その分類に影響した情報を図示する。
viz2 = dtreeviz(clf,
iris.data,
iris.target,
target_name="variety",
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
X=iris.data[100])
viz2.view()
・オレンジの変数が分類に大事だったっぽい。
・問題は解釈する能力なんだなぁ、と思った。
おわり。