2022年8月28日日曜日

線形回帰を求める(2)

前回は散布図で、都道府県別人口と平均年収をプロットしました。 後は線形回帰分析するだけです。 線形回帰はsklearnライブラリを使用することで簡単に分析できます。
from sklearn.linear_model import LinearRegression

X = df3[['人口']]
y = df3['平均年収']

# モデルの作成
model = LinearRegression()

model.fit(X, y)


print("傾き: %f" % model.coef_)
print("切片: %f" % model.intercept_)
print()
print("決定係数: %f" % model.score(X, y))


y = ax + b 

傾き: 0.124986 
切片: 4101341.031155 

決定係数: 0.605080


sklearnを使うだけで、簡単に単回帰分析出来ます。

predictを使い回帰直線を引きます。

人口と平均年収の関係はありそうですね。

predict = model.predict(X)

plt.plot(x, predict, color="coral")

全コード
#! /usr/bin/env python

# -*- coding:utf-8 -*-

#
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
from sklearn.linear_model import LinearRegression

#日本円表示を数値に
def currencyToNumber(val):
    # 万と千以下に分ける
    man, sen = val.split("万")

    #円を取り除き,千以下を0埋め
    sen = sen.replace("円", "")
    sen = sen.zfill(4)

    man += sen
    return int(man)
    

# 数値のカンマを取り除く
def dropComma(val):
    val = val.replace(",", "")

    return float(val)

data1 = "../data/prefecture1.csv"
data2 = "../data/income_pref.csv"

df1 = pd.read_csv(data1)
df2 = pd.read_csv(data2)

print(df1.head())
print()
print(df2.head())
print()

df3 = pd.merge(df1, df2)

print(df3.head())
print()

df3['平均年収'] = df3['平均年収'].apply(currencyToNumber)
df3['人口'] = df3['人口'].apply(dropComma)

x = df3['人口']
X = df3[['人口']]
y = df3['平均年収']


# モデルの作成
model = LinearRegression()

model.fit(X, y)


print("傾き: %f" % model.coef_)
print("切片: %f" % model.intercept_)
print()
print("決定係数: %f" % model.score(X, y))

pref_name = df3['都道府県']

plt.figure(figsize=(16, 10))

predict = model.predict(X)


plt.title("人口と平均年収のグラフ", fontsize=24)
plt.scatter(x, y)
plt.grid()

plt.xlabel("人口(人)", fontsize=18)
plt.ylabel("平均年収(円)", fontsize=18)

plt.plot(x, predict, color="coral")

for i, pref in enumerate(pref_name):
    plt.annotate(pref, (x[i], y[i]))

nums = plt.gca().get_yticks()
plt.gca().set_yticklabels(['{:,.0f}'.format(i) for i in nums])
nums = plt.gca().get_xticks()
plt.gca().set_xticklabels(['{:,.0f}'.format(i) for i in nums])

out = "img/prefecture_data1.png"
plt.savefig(out)
plt.show()

0 件のコメント:

コメントを投稿