BETA

fastai動かしてみたLesson4② テーブルデータの機械学習

投稿日:2019-05-04
最終更新:2019-05-05

今回はLesson4 https://course.fast.ai/videos/?lesson=4

notebookはこちら https://github.com/fastai/course-v3/tree/master/nbs/dl1

自然言語処理、協調フィルタリング、テーブルデータの機械学習について学んでいく

テーブルデータ

テーブルデータとは、Excelの様な表のデータといえば想像しやすいかもしれない

使うデータは収入などが含まれるcsvファイル
こんな感じ

今回はこれを使って機械学習を行い、収入が5万ドル以上稼いでいるかの予測を行う

from fastai import *  
from fastai.tabular import *  

使うデータはpandasで扱えるものだとする

path = untar_data(URLs.ADULT_SAMPLE)  
df = pd.read_csv(path/'adult.csv')  

独立変数、カテゴリの名前、連続変数の名前、プロセッサを決める

dep_var = '>=50k'  
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']  
cont_names = ['age', 'fnlwgt', 'education-num']  
procs = [FillMissing, Categorify, Normalize]

テストデータと訓練データを用意
いつも通りDatablockAPIを使う

test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)  

data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)  
                           .split_by_idx(list(range(800,1000)))  
                           .label_from_df(cols=dep_var)  
                           .add_test(test, label=0)  
                           .databunch())

今回作成した物はTabularListであり、データフレームから作成する

検証データはインデックスの800から1000の物を使う

従って以下の事をTabularListに渡す必要がある
・何のデータフレームか
・モデルと中間ステップを保存するためのpathが何か
・categorical変数と連続変数が何なのか

変数いろいろ

独立変数は予測するために使っているもので、今回のデータで言うと年齢や婚姻状況、職業などのこと
これらのデータは時にはバイナリ(真か偽)のデータであるかもしれない
そこで、何かしらの可能性を持つ選択肢の事をCategorical variableという(要は数値でなくてカテゴリ的な変数の事:真偽とか曜日とか…)
そこで、categorical変数を連続変数を使用するモデルにモデル化を行うには、ニューラルネットでは別の手法を使う必要がある
それをEmbedding(埋め込み)といいのちに説明する

プロセッサ

画像で言うと、写真を回転させたり明るくさせたり正規化する事などを行い、データを変形させた

表形式のデータではそれらの事を事前に行う
例えば、
FillMissing:欠損値に対して何らかの処理を施す
Categorify:Categorical変数を見つけてそれをPandasのカテゴリに変える
Normalize:連続関数からそれらの平均を引き、それらの標準偏差で割る事で、0-1の範囲に収まるように事前に正規化を行う

中身を見るとこんな感じ(さっきも張った)

data.show_batch(rows=10)  

学習器を作る

ここは大体いつも通り
layers=[200, 100]に関しては、今回の講義の最後か次の講義の初めに説明するらしい

learn = tabular_learner(data, layers=[200,100], metrics=accuracy)  
learn.fit(1, 1e-2)

Total time: 00:03
epoch train_loss valid_loss accuracy
1 0.362837 0.413169 0.785000 (00:03)

これで適当にpredictをいつも通り行えばヨシ

技術ブログをはじめよう Qrunch(クランチ)は、プログラマの技術アプトプットに特化したブログサービスです
駆け出しエンジニアからエキスパートまで全ての方々のアウトプットを歓迎しております!
or 外部アカウントで 登録 / ログイン する
クランチについてもっと詳しく

この記事が掲載されているブログ

大学生の書きなぐりブログ 間違ってる事も書いてるので自己責任で勉強しましょう

よく一緒に読まれる記事

0件のコメント

ブログ開設 or ログイン してコメントを送ってみよう
目次をみる
技術ブログをはじめよう Qrunch(クランチ)は、プログラマの技術アプトプットに特化したブログサービスです
or 外部アカウントではじめる
10秒で技術ブログが作れます!