- from sklearn.linear_model import LinearRegression
- X = df3[['人口']]
- y = df3['平均年収']
- # モデルの作成
- model = LinearRegression()
- model.fit(X, y)
- print("傾き: %f" % model.coef_)
- print("切片: %f" % model.intercept_)
- print()
- print("決定係数: %f" % model.score(X, y))
傾き: 0.124986
切片: 4101341.031155
決定係数: 0.605080
sklearnを使うだけで、簡単に単回帰分析出来ます。
predictを使い回帰直線を引きます。
人口と平均年収の関係はありそうですね。
- predict = model.predict(X)
- plt.plot(x, predict, color="coral")
- #! /usr/bin/env python
- # -*- coding:utf-8 -*-
- #
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- import japanize_matplotlib
- from sklearn.linear_model import LinearRegression
- #日本円表示を数値に
- def currencyToNumber(val):
- # 万と千以下に分ける
- man, sen = val.split("万")
- #円を取り除き,千以下を0埋め
- sen = sen.replace("円", "")
- sen = sen.zfill(4)
- man += sen
- return int(man)
- # 数値のカンマを取り除く
- def dropComma(val):
- val = val.replace(",", "")
- return float(val)
- data1 = "../data/prefecture1.csv"
- data2 = "../data/income_pref.csv"
- df1 = pd.read_csv(data1)
- df2 = pd.read_csv(data2)
- print(df1.head())
- print()
- print(df2.head())
- print()
- df3 = pd.merge(df1, df2)
- print(df3.head())
- print()
- df3['平均年収'] = df3['平均年収'].apply(currencyToNumber)
- df3['人口'] = df3['人口'].apply(dropComma)
- x = df3['人口']
- X = df3[['人口']]
- y = df3['平均年収']
- # モデルの作成
- model = LinearRegression()
- model.fit(X, y)
- print("傾き: %f" % model.coef_)
- print("切片: %f" % model.intercept_)
- print()
- print("決定係数: %f" % model.score(X, y))
- pref_name = df3['都道府県']
- plt.figure(figsize=(16, 10))
- predict = model.predict(X)
- plt.title("人口と平均年収のグラフ", fontsize=24)
- plt.scatter(x, y)
- plt.grid()
- plt.xlabel("人口(人)", fontsize=18)
- plt.ylabel("平均年収(円)", fontsize=18)
- plt.plot(x, predict, color="coral")
- for i, pref in enumerate(pref_name):
- plt.annotate(pref, (x[i], y[i]))
- nums = plt.gca().get_yticks()
- plt.gca().set_yticklabels(['{:,.0f}'.format(i) for i in nums])
- nums = plt.gca().get_xticks()
- plt.gca().set_xticklabels(['{:,.0f}'.format(i) for i in nums])
- out = "img/prefecture_data1.png"
- plt.savefig(out)
- plt.show()