いろいろ倉庫

KNIME、EXCEL、R、Pythonなどの備忘録

【Python】分類問題で遊んでみたい。

・お題:機械学習というやつで遊んでみたくて、画像に写っているものが何か分類する問題をやってみようと思った。有名な問題で、手書きの数字画像がなんの数字か判別するモデルを作るやつがあるらしい。やってみたい。

 

・とりあえず、データをロードしてみる。

from sklearn.datasets import load_digits
Data = load_digits(n_class=10)

・Dataの中身が何か分からないので、確認してみる。まず、画像データを確認する。

Data.imagesで以下が返ってくる。

array([[[ 0.,  0.,  5., ...,  1.,  0.,  0.],
        [ 0.,  0., 13., ..., 15.,  5.,  0.],
        [ 0.,  3., 15., ..., 11.,  8.,  0.],
        ...,

・画像データ…なのか?最初の画像データを抜き出してみる。

Data.images[0]で以下が返ってくる。

array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
       [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
       [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
       [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
       [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
       [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
       [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
       [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]])

・画像のイメージが掴めないので、pandasのdataframeに取り込んで、Heatmapにしてみる。

import pandas as pd
df0=pd.DataFrame(Data.images[0])

df0.style.background_gradient(cmap="Greys" )

・おお。これはたぶんゼロだ。数値の大きさがそのまま濃淡を表しているらしい。pandasのdataframeだとまだ表っぽいので、画像っぽく表示したい。ちょっとやってみる。

import matplotlib.pyplot as plt
plt.imshow(Data.images[0], cmap=plt.cm.binary)

・次に、正解のラベルを確認する。

Data.targetで、以下が返ってくる。

array([0, 1, 2, ..., 8, 9, 8])

なるほど正解ラベルのリストっぽい。len(Data.target)で1797が返ってくる。

機械学習では、データセットすべてでモデルを作成するのではなく、データセットを学習用とテスト用に分けて、学習用データでモデル構築、テスト用データでモデルの性能評価をするらしい。とりあえずデータをカチ割る。

X = Data.images#Xは説明変数。画像データに相当するが、今回は数値のリストのリスト。
y = Data.target #yは目的変数。正解ラベルの数値のリスト。

import numpy as np

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

・今回は、画像データの色の濃淡を表す数値を使って、正解ラベルを予測することにする。画像データの色の濃淡が、数値のリストのリストのままだと扱いにくいので、数値のリストにする。FAXのようなイメージ。

X_train_flat=[n.flatten() for n in X_train]
X_test_flat=[n.flatten() for n in X_test]

・これで、X_train_flatは数値のリストになる。

[array([ 0.,  0.,  0.,  8., 10., 14.,  3.,  0.,  0.,  1., 13., 13.,  9.,
        12.,  8.,  0.,  0.,  6., 16.,  8.,  8., 16.,  4.,  0.,  0.,  5.,
        16., 16., 16.,  9.,  0.,  0.,  0.,  0.,  5.,  8., 14., 12.,  0.,
         0.,  0.,  0.,  0.,  3., 16.,  5.,  0.,  0.,  0.,  0.,  0., 15.,
         8.,  0.,  0.,  0.,  0.,  0.,  1., 12.,  2.,  0.,  0.,  0.]),

......

・今回は、random forestを使って分類モデルを作成する。

from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(max_depth=10)
rf.fit(X_train_flat, y_train)

・accuracyを見てみる。
rf.score(X_test_flat, y_test)で0.975925925925926が返ってきた。なんだか良さそう。

・ここで気になるのは、どういう画像で間違えたのか、だと思う。confusion matrixを描いても良いけれど、今回は間違ったものを抽出した方が、どんな画像が間違えやすいか分かりやすい気がする。

・せっかくなので、probabilityも算出して表にしてみる。

import pandas as pd
df = pd.DataFrame(rf.predict_proba(X_test_flat))

df["pred_test"] = pred_test

df["y_test"] = y_test

df_mistake = df[[not n for n in  df["pred_test"] == df["y_test"]]]

・これで、df_mistakeは以下になる。例えば、index 117は、0か2かで悩んで0と予測したらしい。答えが2だったので、惜しかったかもしれない…?よく見ると、1番probabirityが高いクラスはもちろん間違っているが、2番目に高いクラスは結構正解になっている。

・せっかくなので、どんな画像だったか確認してみる。

for n in range(len(df_mistake)):
    plt.figure(figsize=(1,1))
    plt.imshow(X_test[df_mistake.index[n]], cmap=plt.cm.binary)

……

・納得いかない間違い方をしている画像もあるし、あぁこれは確かに紛らわしい(というか人間でも普通に読み間違えるだろう)という画像もあった。Random Forest以外の手法などでやってみても面白いと思うし、ハイパーパラメータをもっとチューニングしても良いと思う。ちょっと楽しい。

 

おわり。