「大人の教養・知識・気付き」を伸ばすブログ

一流の大人(ビジネスマン、政治家、リーダー…)として知っておきたい、教養・社会動向を意外なところから取り上げ学ぶことで“気付く力”を伸ばすブログです。目下、データ分析・語学に力点を置いています。今月(2022年10月)からは多忙につき、日々の投稿数を減らします。

MENU

Juliaを使ってみる(11/22):線形回帰を使ってみる

1. 線形回帰

1.1 線形回帰のシミュレーション

1.1.1 線形回帰
################
### 線形回帰 ###
################

using Distributions
using PyPlot

# グラフの諸設定を行う関数
function set_options(ax, xlabel, ylabel, title;
                      grid = true, gridy = false, legend = false)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if grid
        if gridy
            ax.grid(axis = "y")
        else
            ax.grid()
        end
    end
    legend && ax.legend()
end

### 線形回帰モデルに従うデータセットを生成
function generate_lm(X, σ, μ₁,μ₂,σ₁,σ₂)
    w₁ = rand(Normal(μ₁, σ₁))
    w₂ = rand(Normal(μ₂, σ₂))
    f(x) = w₁*x + w₂
    Y = rand.(Normal.(f.(X), σ))
    Y, f, w₁, w₂
end

### シミュレーション
σ = 1.0
μ₁ = 0.0
μ₂ = 0.0
σ₁ = 10.0
σ₂ = 10.0
X = range(-5,5, length = 10)#[-1.0, 0.5, 0.0, 0.5, 1.0]

# 可視化する範囲
xs = range(-5, 5, length = 100)

fig, axes = subplots(2, 3, sharey = true, figsize = (12,6))

for ax in axes
    # 関数f, 出力Yの生成
    Y, f, w₁, w₂ = generate_lm(X,σ,μ₁, μ₂,σ₁,σ₂)
    
    # 生成された直線とYをプロット
    ax.plot(xs, f.(xs), label = "simulated function")

    ax.scatter(X, Y, label = "simulated data")
    
    set_options(ax, "x", "y", "data(N = $(length(X)))", legend = true)
end

tight_layout()



1.1.2 平均を変えてのシミュレーション
######################################
### 平均を変えてのシミュレーション ###
######################################

σ = 1.0

# 平均のリスト
μ₁= [-20.0, 0.0, 20.0]
μ₂= [-20.0, 0.0, 20.0]

# 標準偏差は固定
σ₁ = 10.0
σ₂ = 10.0


fig, axes = subplots(length(μ₁), length(μ₂), sharey = true, figsize = (12,12))
for (i, μ₃) in enumerate(μ₁)
    for (j,μ₄) in enumerate(μ₂)
        fs = [generate_lm(X, σ, μ₃, μ₄, σ₁, σ₂)[2] for _ in 1:100]
        
        for f in fs
            axes[i, j].plot(xs, f.(xs), "g", alpha = 0.1)
        end
        
        axes[i, j].set_xlim(extrema(xs))
        set_options(axes[i, j], "x", "y", "μ₁ =$(μ₃), μ₂ =$(μ₄)")
    end
end

tight_layout()



1.1.3 標準偏差を変えてのシミュレーション
##########################################
### 標準偏差を変えてのシミュレーション ###
##########################################


# 平均のリスト
σ₁_list = [1.0, 10.0, 20.0]
σ₂_list = [1.0, 10.0, 20.0]

# 標準偏差は固定
μ₁ = 0.0
μ₂ = 0.0


fig, axes = subplots(length(σ₁_list), length(σ₂_list), sharey = true, figsize = (12,12))

for (i, σ₃) in enumerate(σ₁_list)
    for (j,σ₄) in enumerate(σ₂_list)
        fs = [generate_lm(X, σ, μ₁, μ₂, σ₃, σ₄)[2] for _ in 1:100]
        
        for f in fs
            axes[i, j].plot(xs, f.(xs), "g", alpha = 0.1)
        end
        
        axes[i, j].set_xlim(extrema(xs))
        set_options(axes[i, j], "x", "y", "σ₁ =$(σ₃), μ₂ =$(σ₄)")
    end
end

tight_layout()



プライバシーポリシー お問い合わせ