NTTドコモR&Dの技術ブログです。

Stanで動かすベイズ的機械学習 ~医療費データの分析例~

本記事は、ドコモアドベントカレンダー2024 19日目の記事です🎄

こんにちは!NTTドコモ クロステック開発部の畑元です。業務ではヘルスケア領域におけるデータ分析やAI開発を行っています。

この記事ではベイズ推論による機械学習とRStanを用いた分析例をご紹介します。データサイエンス分野の方には馴染みのある話かもしれませんが、私はよく忘れてしまうので頭の整理も兼ねて書いていこうと思います。

※数式が崩れる方は、数式の上で右クリックして、Math Settings > Math Renderer > Common HTMLへ設定をご変更ください

1. はじめに

近年、AIに関する研究は急速に進歩し、あらゆる産業で活用されています。とりわけ生成AIの発展により自然なテキスト・高品質な動画像などを簡単に生成できるようになり、世の中に大きなインパクトを与えました。

生成モデルの根幹にある考え方は、その名の通り「データの生成プロセスをモデル化すること」です。言い換えれば、データの背後にある法則として確率的な構造を仮定し、得られたデータから構造を支配するパラメータを学習することで、新しいデータを生成可能にするという思想です。

データの生成プロセスを確率分布によって記述する手法を統計モデリングと呼びます。現象を確率的にモデリングする利点として、現実世界における不確実性を考慮可能、課題に合わせた柔軟なモデリング、比較的小サンプルでも合理的な推論が可能、等々が挙げられ、機械学習において重要な立ち位置を占めています。

中でもベイズ理論に基づくベイズ推論のアプローチは非常に強力です。 また、様々な統計・機械学習手法はベイズ理論の枠組みの中で記述することができ、この一つの枠組みを理解することで、様々な手法を柔軟に組み合わせる等の幅広い応用が可能になります。

2. ベイズ推論について

実際の分析に入る前にベイズ推論についてざっくり触れておきます。(あまり厳密には説明していませんので詳しく知りたい方は後述の参考書籍をご参照ください)

やりたいことは、データ\displaystyle{D}がとある確率分布に従って生成されると仮定し、その分布の形状を特徴づけるパラメータ\displaystyle{\Theta}を求めることです。\displaystyle{\Theta}を求めることができれば\displaystyle{D}を生成した分布の特徴を知ることができ、新しいデータの生成を行ったり、未知の入力に対する予測を行うことができるようになります。ベイズ推論はこれを実現する一つの方法です。

ベイズ推論自体は、ベイズの定理に基づく至ってシンプルな確率分布の更新プロセスです。まずは基本となるベイズの定理を導入していきます。

ベイズの定理

確率変数\displaystyle{A}\displaystyle{B}に対し、同時分布を\displaystyle{p(A,B)}とすると、\displaystyle{B}に関する周辺分布\displaystyle{A}が与えられたときの\displaystyle{B}条件付き分布を以下のように書きます。

  • 周辺分布
    \displaystyle{
p(B) = \int p(A,B) \mathrm{d} A \tag{1}
}
    ※確率変数が離散型の場合は、\displaystyle{p(B) = \sum _ A p(A,B)}
  • 条件付き分布
    \displaystyle{
p(B|A) = \frac{p(A,B)}{p(A)} \tag{2}
}

\displaystyle{(1)}および\displaystyle{(2)}より、以下のベイズ定理を導くことができます。

  • ベイズの定理
    \displaystyle{
p(B|A) = \frac{p(A|B) p(B)}{p(A)} = \frac{p(A|B) p(B)}{\int p(A,B) \mathrm{d} B} \tag{3}
}

ベイズの定理は単に確率の基本公式の変形によって導出される式に過ぎませんが、この式は面白い性質を示しています。\displaystyle{B}を原因、\displaystyle{A}を結果と解釈すると、ベイズの定理によれば、原因\displaystyle{B}から結果\displaystyle{A}が得られる確率\displaystyle{p(A|B)}をもとに、結果\displaystyle{A}が得られた時の原因\displaystyle{B}の確率\displaystyle{p(B|A)}を逆算できる、ということです。

ベイズ推論

この性質を利用すれば、実際に観測したデータ\displaystyle{D}(結果)からそのデータを生成した確率モデルのパラメータ\displaystyle{\Theta}(原因)を推定することができます。この関係をベイズの定理に当てはめて記述すると以下のようになります。 ※ベイズ推論の枠組みにおいては、推定対象のパラメータ\displaystyle{\Theta}も確率変数として扱います。

\displaystyle{
p(\Theta|D) = \frac{p(D|\Theta) p(\Theta)}{p(D)} \tag{4}
}

ここで、\displaystyle{p(D|\Theta)}尤度\displaystyle{p(\Theta)}事前分布\displaystyle{p(\Theta|D)}事後分布と呼びます。尤度はパラメータ\displaystyle{\Theta}が与えられた時のデータ\displaystyle{D}の発生しやすさを表す関数で、課題に合わせて自由にモデルを仮定することができます。事前分布はパラメータ\displaystyle{\Theta}の分布であり、そもそもどのような\displaystyle{\Theta}が得られやすいか、という事前に持っている仮説を表現します。事後分布もパラメータ\displaystyle{\Theta}の分布ですが、事前分布に尤度が乗算されることで、観測データを考慮して更新された\displaystyle{\Theta}の分布と解釈できます。また、分母の\displaystyle{p(D) = \int p(D, \Theta) \mathrm{d} \Theta}周辺尤度またはエビデンスと呼ばれます。周辺尤度は\displaystyle{\Theta}に関して定数項であり、事後分布の積分が\displaystyle{1}になることを保証するための正規化定数となります。

多くの場合、事後分布\displaystyle{p(\Theta|D)}を解析的に計算することは困難です。そこで、事後分布を求めるために主に以下3つの戦略をとることができます。

  1. 共役事前分布 尤度関数に規定した確率分布に対して、事前分布に特定の確率分布を規定することで、事後分布が事前分布と同型の分布になる組み合わせが知られており、これを共役事前分布といいます(例:ベルヌーイ分布に対するベータ分布、正規分布に対する逆ガンマ分布)。事前分布に共役事前分布を設定した場合、事後分布を解析的に計算でき非常に扱いやすいです。しかし、事前分布設計の自由が制限されるため、実用性は低いとされます。

  2. MCMC(マルコフ連鎖モンテカルロ法) MCMCは、近似的に事後分布に従う乱数を発生させる手法です。これにより事後分布を直接求めることはできなくても、シミュレーションから分布の特徴を知ることができます。MCMCのアルゴリズムとして、「メトロポリス法・ヘイスティング法」、「ギブスサンプリング」、「ハミルトニアンモンテカルロ法」などのいくつかの方法が存在します。

  3. 変分推論 変分推論は、求めたい\displaystyle{p(\Theta|D)}を別の新たな関数\displaystyle{q(\Theta; \eta)}により近似する手法です。この方法では、2つの確率分布の差異の大きさを表現する指標であるKLダイバージェンスを最小化する近似関数のパラメータ\displaystyle{\eta}の最適化問題を解きます。

\displaystyle{
   \eta_{\mathrm{opt.}} = \underset{\eta}{\mathrm{argmin}} \ \mathrm{KL}[q(\Theta;\eta) || p(\Theta|D)] \tag{5}
}

こちらの最適化も通常は解析解を得ることができないため、勾配法などの数値計算が利用されます。

近年では計算機の発達に伴い、MCMCや変分推論による近似推論が主流となっています。3. ではMCMCによる解法を試してみたいと思います。

ベイズ的機械学習

ここでは例として、教師あり学習がベイズ的にどのように表現できるかを考えていきます。

観測されている入力変数の集合\displaystyle{X}および出力変数の集合\displaystyle{Y}が与えられており、モデルパラメータを\displaystyle{W}とします。さらに、新規入力\displaystyle{x _ \ast}に対する予測出力を\displaystyle{y _ \ast}とします。このとき各確率変数間の関係は以下に示す図の通りです。このように確率変数間の依存関係をグラフ構造で表現したものをグラフィカルモデルといいます。

以上の関係を同時分布で記述すると次のようになります。

\displaystyle{
p(x_*, y_* , X, Y, W) = p(y_* | x_*, W)p(Y|X, W)p(x_*)p(X)p(W) \tag{6}
}

求めたい予測出力\displaystyle{y _ \ast}の分布は、観測データ\displaystyle{X, Y}および新規入力\displaystyle{x _ \ast}が与えられた時の条件付き分布\displaystyle{p(y _ \ast | x _ \ast, X, Y)}によって得られます。したがって、\displaystyle{(6)}式について\displaystyle{W}を周辺化し、\displaystyle{(x _ \ast, X, Y)}が得られた下での条件付き分布を求めると、

\displaystyle{
p(y_*|x_*, X, Y) = \frac{\int p(y_*|x_*, W)p(Y|X, W)p(x_*)p(X)p(W) \mathrm{d}W}{p(x_*, X, Y)} \\
 \quad \quad \quad \quad \quad \quad = \int p(y_*|x_*, W) p(W|X, Y) \mathrm{d}W \tag{7}
}

ここで、パラメータ\displaystyle{W}の事後分布はベイズ推論により以下のように求められます。

\displaystyle{
p(W|X, Y) = \frac{p(Y|X, W) p(W)}{p(Y|X)} \tag{8}
}

以上より、ベイズ推論に基づく教師あり機械学習は、\displaystyle{(8)}式の観測データ\displaystyle{X,Y}を用いてパラメータ\displaystyle{W}の事後分布を行う「学習」ステップ、\displaystyle{(7)}式の求めたパラメータの事後分布をもとに予測分布を求める「予測」ステップを行っていると解釈できます。

3. 実際に動かしてみる

準備

ベイズ的な機械学習の例を実際に動かしてみます。今回はKaggleで公開されているMedical Cost Personal Datasetsを使用します。このデータセットは、アメリカの医療保険契約者に関する基本的なデータであり、契約者の年齢・性別・BMI・子供の数・喫煙有無・居住地域・医療費の情報が含まれています。これを利用して契約者の基本情報から医療費を予測するような簡単な問題設定を考えてみます。

また、統計モデルの実装にあたり、確率的プログラミングというパラダイムを利用します。これは確率分布を扱うモデルをプログラムの形で記述し、そのモデルに基づいて推論やデータの生成を行うものです。確率的プログラミング言語には、Stan、PyMC、Edwardなど様々なものがありますが、今回は柔軟性が高く、扱いやすいStanを使っていきます。Stanで記述したモデルはRやPythonから呼び出して利用できます。今回の例ではRを使っています。

(参考)RStan Getting Started (Japanese))

探索的データ分析

まずはこのデータセットをRで読み出し、データの特徴を確認していきます。

# 基本ライブラリの読み込み
library(rstan)
library(ggplot2)
library(dplyr)

# データの読み込み
df <- read.csv(file='../data/insurance.csv') %>%
  as.data.frame()

# 重複しているレコードを削除
df <- df[!duplicated(df),]

# 先頭レコードを表示
head(df)

# 基本統計量の確認
summary(df)

# Charges(医療費)の分布を確認
plot_charges <- ggplot(df, aes(x = charges)) +
  geom_histogram(bins = 40, aes(y = ..density..), fill = "#C59A5A", color = "black", alpha = 0.7) +
  geom_density(color = "blue", size = 0.7) +
  labs(x = "Charges", y = "Density", title = "Charges Distribution")

plot_charges

医療費の分布は右にテールが長い非対称な分布であることがわかります。後ほどモデルに組み込む際に考慮すべき特徴になります。次に数値型の変数(年齢・BMI・子供の数)について分布と目的変数(医療費)との関係を確認していきます。

library(gridExtra)

# 数値型のカラムについてヒストグラムで分布を確認
# age
plot_age <- ggplot(df, aes(x = age)) +
  geom_histogram(binwidth = 1, aes(y = ..density..), fill = "#C55A71", color = "black", alpha = 0.7) +
  labs(x = "Age", y = "Density", title = "Age Distribution")

# bmi
plot_bmi <- ggplot(df, aes(x = bmi)) +
  geom_histogram(bins = 30, aes(y = ..density..), fill = "#5A9EC5", color = "black", alpha = 0.7) +
  geom_density(color = "blue", size = 0.7) +
  labs(x = "BMI", y = "Density", title = "BMI Distribution")

# children
plot_children <- ggplot(df, aes(x = children)) +
  geom_histogram(binwidth = 1, aes(y = ..density..), fill = "#5AC573", color = "black", alpha = 0.7) +
  labs(x = "Children", y = "Density", title = "Children Distribution")

# ageとchargesの散布図
plot_age_charges <- ggplot(df, aes(x = age, y = charges)) +
  geom_point(color = "#C55A71", alpha = 0.7) +
  labs(x = "Age", y = "Charges", title = "Age vs Charges")

# bmiとchargesの散布図
plot_bmi_charges <- ggplot(df, aes(x = bmi, y = charges)) +
  geom_point(color = "#5A9EC5", alpha = 0.7) +
  labs(x = "BMI", y = "Charges", title = "BMI vs Charges")

# childrenとchargesの散布図
plot_children_charges <- ggplot(df, aes(x = children, y = charges)) +
  geom_point(color = "#5AC573", alpha = 0.7) +
  labs(x = "Children", y = "Charges", title = "Children vs Charges")

# グラフをまとめて表示
grid.arrange(plot_age, plot_bmi, plot_children, plot_age_charges, plot_bmi_charges, plot_children_charges, ncol = 3)

おおむね以下の傾向が読み取れます。

  • 年齢については20歳以下のデータが多く、高年齢ほど医療費が高くなるが3種類の傾向に分かれているように見える
  • BMIについては30付近を平均として正規分布し、BMIが高いほど医療費が高くなる傾向とBMIに対し医療費が横ばいの2種類の傾向がみえる
  • 子供の数が増えるほど医療費が下がるが、子供が4人以上のデータは全体の5%程度であり信頼性には欠ける

次にカテゴリ型の変数(性別・喫煙有無・居住地域)について分布と目的変数との関係を確認していきます。

# カテゴリ型の変数について円グラフで分布を確認
create_pie_chart <- function(data, column, title) {
  data %>%
    count(!!sym(column)) %>%
    mutate(percentage = n / sum(n) * 100) %>%
    ggplot(aes(x = "", y = n, fill = !!sym(column))) +
    geom_bar(stat = "identity", width = 1, color = "white") +
    coord_polar(theta = "y") +
    labs(fill = column, title = title, y = "", x = "") +
    geom_text(aes(label = paste0(round(percentage, 1), "%")), 
              position = position_stack(vjust = 0.5), size = 4) +
    theme_minimal() +
    theme(axis.text = element_blank(),
          axis.ticks = element_blank(),
          panel.grid = element_blank())
}

# sex
plot_sex <- create_pie_chart(df, "sex", "Sex Distribution")

# smoker
plot_smoker <- create_pie_chart(df, "smoker", "Smoker Distribution")

# region
plot_region <- create_pie_chart(df, "region", "Region Distribution")

# sexとchargesのボックスプロット
plot_sex_charges <- ggplot(df, aes(x = sex, y = charges, fill = sex)) +
  geom_boxplot(alpha = 0.7, outlier.color = "red", outlier.shape = 16) +
  labs(x = "Sex", y = "Charges", title = "Sex vs Charges") +
  theme_minimal() +
  theme(legend.position = "none")

# smokerとchargesのボックスプロット
plot_smoker_charges <- ggplot(df, aes(x = smoker, y = charges, fill = smoker)) +
  geom_boxplot(alpha = 0.7, outlier.color = "red", outlier.shape = 16) +
  labs(x = "Smoker", y = "Charges", title = "Smoker vs Charges") +
  theme_minimal() +
  theme(legend.position = "none")

# regionとchargesのボックスプロット
plot_region_charges <- ggplot(df, aes(x = region, y = charges, fill = region)) +
  geom_boxplot(alpha = 0.7, outlier.color = "red", outlier.shape = 16) +
  labs(x = "Region", y = "Charges", title = "Region vs Charges") +
  theme_minimal() +
  theme(legend.position = "none")

# グラフをまとめて表示
grid.arrange(plot_sex, plot_smoker, plot_region, plot_sex_charges, plot_smoker_charges, plot_region_charges, ncol = 3)

  • 男女の割合は半々で、医療費に大きな違いはない
  • 喫煙者の割合は少なく、非喫煙者と比較して医療費が非常に大きい
  • 居住地域の割合はそれぞれ同程度で医療費に大きな違いはない

ベイズ線形回帰

各変数と医療費の基本的な関係が確認できました。今回はシンプルな線形回帰モデルに当てはめて考えてみます。線形回帰モデルでは、\displaystyle{D}次元の入力ベクトル\displaystyle{\mathbf{x} \in \mathbb{R}^D}を任意の基底関数\displaystyle{\boldsymbol{\phi(\cdot)}}により\displaystyle{H}次元の特徴空間に写像した\displaystyle{\boldsymbol{\phi(\mathbf{x})}}と重みベクトル\displaystyle{\mathbf{w} \in \mathbb{R}^H}の線形結合、および平均\displaystyle{0}、分散\displaystyle{\sigma^ 2}のガウス分布に従うノイズ項\displaystyle{\epsilon}により出力\displaystyle{y \in \mathbb{R}}が表現されます。

\displaystyle{
y = \mathbf{w}^\top \boldsymbol{\phi(\mathbf{x})} + \epsilon \\
\epsilon \sim \mathcal{N} (0, \sigma^2) \tag{9}
}

\displaystyle{(9)}式より、\displaystyle{y}は平均\displaystyle{\mathbf{w}^\top \boldsymbol{\phi(\mathbf{x})}}、分散\displaystyle{\sigma^ 2}のガウス分布に従う確率変数となります。したがって、\displaystyle{\mathbf{w}, \sigma^ 2}\displaystyle{\mathbf{x}}が得られた下での\displaystyle{y}の条件付き分布は、以下のように書けます。

\displaystyle{
p(y|\mathbf{x}, \mathbf{w}, \sigma^2) = \mathcal{N} (\mathbf{w}^\top \boldsymbol{\phi(\mathbf{x})}, \sigma^2) \tag{10}
}

推論(学習)したいパラメータは\displaystyle{\mathbf{w}, \sigma^ 2}であるため、事前分布\displaystyle{p(\mathbf{w})}, \displaystyle{p(\sigma^ 2)}を設定し、学習データ\displaystyle{\mathbf{x}, y}を観測した下での事後分布\displaystyle{p(\mathbf{w}, \sigma^ 2|\mathbf{x}, y)}をベイズ推論により求めます。\displaystyle{(8)}式を利用して、事後分布は以下で求められます。

\displaystyle{
p(\mathbf{w}, \sigma^2|\mathbf{x}, y) = \frac{p(y|\mathbf{x}, \mathbf{w}, \sigma^2) p(\mathbf{w}) p(\sigma^2)}{p(y|\mathbf{x})} \tag{11}
}

基底関数について、\displaystyle{\boldsymbol{\phi(\mathbf{x})} = \mathbf{x}}とする場合、一般的に重回帰と呼ばれるモデルになります。今回の問題で説明変数は\displaystyle{
\mathbf{x} = (1, x _ {\mathrm{age}}, x _ {\mathrm{sex}}, x _ {\mathrm{bmi}}, x _ {\mathrm{children}}, x _ {\mathrm{smoker}}, x _ {\mathrm{region}})
}、重みベクトルは\displaystyle{
\mathbf{w} = (\alpha, \beta _ {\mathrm{age}}, \beta _ {\mathrm{sex}}, \beta _ {\mathrm{bmi}}, \beta _ {\mathrm{children}}, \beta _ {\mathrm{smoker}}, \beta _ {\mathrm{region}})
}です。

また、今回は\displaystyle{\mathbf{w}}\displaystyle{\sigma^ 2}の事前分布を仮定するにあたり事前に持ち合わせている根拠や情報は特にありません。このような場合、十分に広い幅を持つ一様分布が事前分布としてよく用いられます。このような分布を無情報事前分布*1といいます。

以上のモデルをStanで記述すると以下のようになります。

// dataブロックではモデルに与える既知の観測データや固定値を定義します。
// 今回はテストデータに対する予測分布まで求めるためテストデータ用の変数を併せて定義しています。
data {
  int<lower=0> N;                // 訓練データのサンプル数
  vector[N] age;                 // age列
  vector[N] sex;                 // sex列
  vector[N] bmi;                 // bmi列
  vector[N] children;            // children列
  vector[N] smoker;              // smoker列
  vector[N] region;              // region列
  vector[N] y;                   // 目的変数 (charges)
  int<lower=0> N_test;           // テストデータのサンプル数
  vector[N_test] age_test;       // テストデータのage列
  vector[N_test] sex_test;       // テストデータのsex列
  vector[N_test] bmi_test;       // テストデータのbmi列
  vector[N_test] children_test;  // テストデータのchildren列
  vector[N_test] smoker_test;    // テストデータのsmoker列
  vector[N_test] region_test;    // テストデータのregion列
}

// parametersブロックでは推定すべき未知のパラメータを定義します。
parameters {
  real alpha;                    // 切片
  real beta_age;                 // ageの係数
  real beta_sex;                 // sexの係数
  real beta_bmi;                 // bmiの係数
  real beta_children;            // childrenの係数
  real beta_smoker;              // smokerの係数
  real beta_region;              // regionの係数
  real<lower=0> sigma;           // 残差の標準偏差
}

// modelブロックではパラメータの事前分布やモデル構造を定義します。
model {
  // 事前分布を設定する場合はここに記載します。
  // e.g. alpha ~ normal(0, 100);
  // 省略した場合は無情報事前分布として、十分に幅の広い一様分布が設定されます。

  // 尤度
  y ~ normal(
    alpha + 
    beta_age * age + 
    beta_sex * sex + 
    beta_bmi * bmi + 
    beta_children * children + 
    beta_smoker * smoker + 
    beta_region * region,
    sigma
  );
}

// generated quantitiesブロックは推定結果から派生する値や予測値を生成します。
// 今回はテストデータに対する予測分布まで求めるため以下で定義しています。
generated quantities {
  vector[N_test] y_test_pred;    
  for (i in 1:N_test) {
    y_test_pred[i] = normal_rng(
      alpha + 
      beta_age * age_test[i] + 
      beta_sex * sex_test[i] + 
      beta_bmi * bmi_test[i] + 
      beta_children * children_test[i] + 
      beta_smoker * smoker_test[i] + 
      beta_region * region_test[i],
      sigma
    );
  }
}

上記をmodel1.stanというファイル名で保存しておき、この後Rから読み込んで利用します。

# ラベルエンコーディング
df <- df %>%
  mutate(
    sex = as.numeric(factor(sex, levels = unique(sex))) - 1,
    smoker = 1 - (as.numeric(factor(smoker, levels = unique(smoker))) - 1),
    region = as.numeric(factor(region, levels = unique(region)))
  )

# データを訓練データとテストデータに分割
trainIndex <- createDataPartition(df$charges, p = 0.8, list = FALSE)
train_data <- df[trainIndex, ]
test_data <- df[-trainIndex, ]

# Stanに渡すデータリストの作成
stan_data <- list(
  N = nrow(train_data),
  age = train_data$age,
  sex = train_data$sex,
  bmi = train_data$bmi,
  children = train_data$children,
  smoker = train_data$smoker,
  region = train_data$region,
  y = train_data$charges,
  
  N_test = nrow(test_data),
  age_test = test_data$age,
  sex_test = test_data$sex,
  bmi_test = test_data$bmi,
  children_test = test_data$children,
  smoker_test = test_data$smoker,
  region_test = test_data$region
)

# Stanモデルの実行
# Stanでは"No-U-turn sampler (NUTS)"というMCMC手法がデフォルトで利用されます。
fit <- stan(
  file = "model1.stan",
  data = stan_data,
  iter = 4000,
  chains = 4
)

fitにはMCMCによって得られたすべてのパラメータのサンプリング結果が格納されています。以下のように結果を確認できます。

# サンプリング結果の確認(一部のみ抜粋)
print(fit, pars = c("alpha", "beta_age", "beta_sex", "sigma", "y_test_pred[1]", "y_test_pred[2]"))

結果には各パラメータのサンプル平均 (mean)、サンプル平均の標準誤差 (se_mean)、パーセンタイル値などが含まれています。最小二乗法による解法や最尤推定等の点推定アプローチと異なり、推定結果が不確実性を伴った分布の形で得られていることに注目してください。\displaystyle{\hat{R}} (Rhat)はモデルの収束指標であり、\displaystyle{\hat{R}}が1に近いほどサンプルの収束が良いことを示しています。

次にテストデータに対する予測値を確認してみます。y_test_pred[i]に予測分布が格納されていますが、これは1点に定まる値ではないので、今回はサンプル平均を代表値として確認してみます。

# サンプルの抽出
y_test_pred_samples <- rstan::extract(fit, pars = "y_test_pred")$y_test_pred

# サンプル平均値を予測値とする
y_test_pred_mean <- colMeans(y_test_pred_samples)
comparison <- data.frame(
  Actual = test_data$charges,
  Predicted = y_test_pred_mean,
  Smoker = factor(test_data$smoker)
)

axis_limit <- range(
  c(comparison$Actual, comparison$Predicted), 
  na.rm = TRUE,
  finite = TRUE
)

# 真値と推定値の散布図をプロット
ggplot(comparison, aes(x = Actual, y = Predicted, color = Smoker)) +
  geom_point(alpha = 0.6) +
  geom_abline(slope = 1, intercept = 0, color = "red", linetype = "dashed") +
  labs(
    title = "Actual vs Predicted Charges by Smoker Status", 
    x = "Actual Charges", 
    y = "Predicted Charges", 
    color = "Smoker"
  ) +
  theme_minimal() +
  coord_fixed(ratio = 1) +
  scale_x_continuous(limits = axis_limit) +
  scale_y_continuous(limits = axis_limit)

プロットは喫煙の有無で色分けをしています。こちらの結果を見ると低額層(非喫煙者)に対してはある程度推定ができていますが、高額層(喫煙者)をうまく表現できていません。喫煙の有無による特徴の違いなどをより詳しく調べて、モデル構造を改善する必要がありそうです。

階層モデル

次に非喫煙者と喫煙者で分けて傾向を見てみます。

# bmiとchargesの散布図(smoker別)
plot_bmi_smoker <- ggplot(df, aes(x = bmi, y = charges, color = smoker)) +
  geom_point(alpha = 0.7) +
  labs(title = "BMI vs Charges",
       x = "BMI",
       y = "Charges") +
  guides(colour=FALSE) +
  theme_minimal()

# ageとchargesの散布図(smoker別)
plot_age_smoker <- ggplot(df, aes(x = age, y = charges, color = smoker)) +
  geom_point(alpha = 0.7) +
  labs(title = "Age vs Charges",
       x = "Age",
       y = "Charges") +
  guides(colour=FALSE) +
  theme_minimal()

# グラフをまとめて表示
grid.arrange(plot_bmi_smoker, plot_age_smoker, ncol = 2)

喫煙の有無によって各変数の医療費に対する傾きや切片が大きく異なっていることが読み取れます。そこで喫煙有無グループのようなものを考え、グループごとに異なるパラメータを持つような以下のモデルを考えてみます。なお、医療費の分布は右にテールが伸びた非対称な分布であったため、この影響を吸収するために対数化した医療費を目的変数とします。また、モデルの簡単化のため、医療費に対して影響が小さいと考えられるsex, children, regionは考慮していません。

  1. 喫煙者(smoker)のモデル
\displaystyle{
\mathrm{ln} (y_i) \sim \mathcal{N} (\mu_{\mathrm{smoker}, i}, \sigma^2_{\mathrm{smoker}}) \\
\mu_{\mathrm{smoker}, i} = \alpha_\mathrm{smoker} + \\
\quad \quad \quad \quad \beta_\mathrm{age, smoker} \cdot x_{\mathrm{age}, i} + \\
\quad \quad \quad \quad \beta_\mathrm{bmi, smoker} \cdot x_{\mathrm{bmi}, i} + \\
\quad \quad \quad \quad \beta_\mathrm{smoker} \cdot x_{\mathrm{smoker}, i} \tag{12}
}
  1. 非喫煙者(non-smoker)のモデル
\displaystyle{
\mathrm{ln} (y_i) \sim \mathcal{N} (\mu_{\mathrm{non–smoker}, i}, \sigma^2_{\mathrm{non–smoker}}) \\\mu_{\mathrm{non–smoker}, i} = \alpha_\mathrm{non–smoker} + \\
\quad \quad \quad \quad \beta_\mathrm{age, non–smoker} \cdot x_{\mathrm{age}, i} + \\
\quad \quad \quad \quad \beta_\mathrm{bmi, non–smoker} \cdot x_{\mathrm{bmi}, i} + \\
\quad \quad \quad \quad \beta_\mathrm{smoker} \cdot x_{\mathrm{non–smoker}, i} \tag{13}
}

ただし、

\displaystyle{
\alpha_\mathrm{smoker} \sim \mathcal{N} (\mu_\alpha, \tau_\alpha) \\
\alpha_\mathrm{non–smoker} \sim \mathcal{N} (\mu_\alpha, \tau_\alpha) \\
\beta_\mathrm{age, smoker} \sim \mathcal{N} (\mu_{\beta, \mathrm{age}}, \tau_{\beta, \mathrm{age}}) \\
\beta_\mathrm{age, non–smoker} \sim \mathcal{N} (\mu_{\beta, \mathrm{age}}, \tau_{\beta, \mathrm{age}}) \\
\beta_\mathrm{bmi, smoker} \sim \mathcal{N} (\mu_{\beta, \mathrm{bmi}}, \tau_{\beta, \mathrm{bmi}}) \\
\beta_\mathrm{bmi, non–smoker} \sim \mathcal{N} (\mu_{\beta, \mathrm{bmi}}, \tau_{\beta, \mathrm{bmi}}) \tag{14}
}

このモデルでは、パラメータが喫煙の有無によって異なるが、これらはより上位の同じ分布から生成される(似たような傾向となる)という制約を持っています。このようなモデルは階層モデルと呼ばれます。このモデルをStanで記述すると、

data {
  int<lower=0> N;                  // 訓練データのサンプル数
  vector[N] age;                   // 年齢
  vector[N] bmi;                   // BMI
  int<lower=0,upper=1> smoker[N];  // 喫煙ステータス(0: 非喫煙者、1: 喫煙者)
  vector[N] y;                     // 医療費(対数変換したもの)

  int<lower=0> N_test;             // テストデータのサンプル数
  vector[N_test] age_test;         // テストデータの年齢
  vector[N_test] bmi_test;         // テストデータのBMI
  int<lower=0,upper=1> smoker_test[N_test]; // テストデータの喫煙ステータス
}

parameters {
  // 上位分布のパラメータ
  real mu_alpha;                  // 切片の平均
  real<lower=0> tau_alpha;        // 切片の標準偏差
  real mu_beta_age;               // 年齢係数の平均
  real<lower=0> tau_beta_age;     // 年齢係数の標準偏差
  real mu_beta_bmi;               // BMI係数の平均
  real<lower=0> tau_beta_bmi;     // BMI係数の標準偏差
  real mu_beta_smoker;            // 喫煙効果の平均
  real<lower=0> tau_beta_smoker;  // 喫煙効果の標準偏差

  // 喫煙グループごとのパラメータ
  real alpha_smoker;              // 喫煙者の切片
  real alpha_non_smoker;          // 非喫煙者の切片
  real beta_age_smoker;           // 喫煙者の年齢係数
  real beta_age_non_smoker;       // 非喫煙者の年齢係数
  real beta_bmi_smoker;           // 喫煙者のBMI係数
  real beta_bmi_non_smoker;       // 非喫煙者のBMI係数
  real beta_smoker;               // 喫煙効果の回帰係数

  // 標準偏差
  real<lower=0> sigma_smoker;
  real<lower=0> sigma_non_smoker;
}

model {
  // 上位分布の事前分布
  mu_alpha ~ normal(0, 10);
  tau_alpha ~ cauchy(0, 2);
  mu_beta_age ~ normal(0, 1);
  tau_beta_age ~ cauchy(0, 2);
  mu_beta_bmi ~ normal(0, 1);
  tau_beta_bmi ~ cauchy(0, 2);
  mu_beta_smoker ~ normal(0, 1);
  tau_beta_smoker ~ cauchy(0, 2);

  // 喫煙グループごとのパラメータの事前分布(階層構造)
  alpha_smoker ~ normal(mu_alpha, tau_alpha);
  alpha_non_smoker ~ normal(mu_alpha, tau_alpha);
  beta_age_smoker ~ normal(mu_beta_age, tau_beta_age);
  beta_age_non_smoker ~ normal(mu_beta_age, tau_beta_age);
  beta_bmi_smoker ~ normal(mu_beta_bmi, tau_beta_bmi);
  beta_bmi_non_smoker ~ normal(mu_beta_bmi, tau_beta_bmi);
  beta_smoker ~ normal(mu_beta_smoker, tau_beta_smoker);

  // 標準偏差の事前分布
  sigma_smoker ~ normal(0, 1);
  sigma_non_smoker ~ normal(0, 1);

  // 尤度
  for (i in 1:N) {
    if (smoker[i] == 1) {
      y[i] ~ normal(
        alpha_smoker + beta_age_smoker * age[i] + beta_bmi_smoker * bmi[i] + beta_smoker * smoker[i],
        sigma_smoker
      );
    } else {
      y[i] ~ normal(
        alpha_non_smoker + beta_age_non_smoker * age[i] + beta_bmi_non_smoker * bmi[i] + beta_smoker * smoker[i],
        sigma_non_smoker
      );
    }
  }
}

generated quantities {
  vector[N_test] y_test_pred;

  for (i in 1:N_test) {
    if (smoker_test[i] == 1) {
      y_test_pred[i] = exp(normal_rng(
        alpha_smoker + beta_age_smoker * age_test[i] + beta_bmi_smoker * bmi_test[i] + beta_smoker * smoker_test[i],
        sigma_smoker
      ));
    } else {
      y_test_pred[i] = exp(normal_rng(
        alpha_non_smoker + beta_age_non_smoker * age_test[i] + beta_bmi_non_smoker * bmi_test[i] + beta_smoker * smoker_test[i],
        sigma_non_smoker
      ));
    }
  }
}

上記をmodel2.stanとして保存し、Rから呼び出して推論を実行します。

# chargesの対数変換
df <- df %>%
  mutate(
    log_charges = log(charges)  
  )

# データを訓練データとテストデータに分割
trainIndex <- createDataPartition(df$log_charges, p = 0.8, list = FALSE)
train_data <- df[trainIndex, ]
test_data <- df[-trainIndex, ]

# 訓練データの平均と標準偏差を計算
age_mean <- mean(train_data$age)
age_sd <- sd(train_data$age)
bmi_mean <- mean(train_data$bmi)
bmi_sd <- sd(train_data$bmi)

# 訓練データの正規化
train_data <- train_data %>%
  mutate(
    age = (age - age_mean) / age_sd,
    bmi = (bmi - bmi_mean) / bmi_sd
  )

# テストデータの正規化
test_data <- test_data %>%
  mutate(
    age = (age - age_mean) / age_sd,
    bmi = (bmi - bmi_mean) / bmi_sd
  )

# Stan用データを準備 (個別のベクトルとして渡す)
stan_data <- list(
  N = nrow(train_data),
  age = train_data$age,
  bmi = train_data$bmi,
  smoker = train_data$smoker,
  y = train_data$log_charges,
  N_test = nrow(test_data),
  age_test = test_data$age,
  bmi_test = test_data$bmi,
  smoker_test = test_data$smoker
)

# モデルを実行
fit <- stan(
  file = "model2.stan",
  data = stan_data,
  iter = 4000,
  chains = 4
)

# サンプリング結果の確認(一部のみ抜粋)
print(fit, pars = c("alpha_smoker", "alpha_non_smoker", "beta_bmi_smoker", "beta_bmi_non_smoker"))

サンプリング結果を見ると喫煙の有無によって各パラメータの値が異なっていることがわかります。次に予測値の結果を確認します。

# サンプルの抽出
y_test_pred_samples <- rstan::extract(fit, pars = "y_test_pred")$y_test_pred

y_test_pred_mean <- colMeans(y_test_pred_samples)
comparison <- data.frame(
  Actual = test_data$charges,
  Predicted = y_test_pred_mean,
  Smoker = factor(test_data$smoker)
)

axis_limit <- range(
  c(comparison$Actual, comparison$Predicted), 
  na.rm = TRUE,
  finite = TRUE
)

ggplot(comparison, aes(x = Actual, y = Predicted, color = Smoker)) +
  geom_point(alpha = 0.6) +
  geom_abline(slope = 1, intercept = 0, color = "red", linetype = "dashed") +
  labs(
    title = "Actual vs Predicted Charges", 
    x = "Actual Charges", 
    y = "Predicted Charges", 
    color = "Smoker"
  ) +
  theme_minimal() +
  coord_fixed(ratio = 1) +
  scale_x_continuous(limits = axis_limit) +
  scale_y_continuous(limits = axis_limit)

全体的に予測性能が改善しましたが、未だ喫煙者の残差が大きかったり、非喫煙者の一部がうまく表現できていないことがわかります。本記事での分析はここまでになりますが、さらなる改善案としては、変数間の交互作用を導入する、喫煙者の医療費に対する二峰性の分布を混合分布モデルで表現する、地域差を考慮した階層モデルを導入する、知見に基づいた適切な分布の設定などが挙げられます。

4. おわりに

この記事ではStanを利用したベイズ的な機械学習についてご紹介しました。今回の問題設定は簡単なものであり、一般的に利用する線形回帰モデルで求めた結果と大きくは変わらないと思います。しかし、ベイズ理論では事前知識や不確実性をモデルに組み込むことができ、解釈性の高さ、小規模データに対する頑健さ、逐次学習など様々な魅力があります。普段はBoostingモデルに突っ込んで終わり!という方も実際にデータを見て、Stanなどでモデルを作って動かしてみると確率的なモデリングの面白さが掴めてくるかと思いますので、興味があれば触ってみてください!

ありがとうございました。

参考書籍

  1. 須山敦志. ベイズ推論による機械学習入門. 講談社, 2017.
  2. 須山敦志. ベイズ深層学習. 講談社, 2019.
  3. 松浦健太郎. StanとRでベイズ統計モデリング. 共立出版, 2017.

*1:補足として、近年は事前分布において無情報事前分布ではなく、最低限の情報を与える弱情報事前分布を用いた方が良いとされております。例えば、標準化偏回帰係数であればその値は高々-1~1におさまるため、平均は0、scaleは1〜2ほど、自由度3〜7ほどのt分布による弱情報事前分布が推奨されております。自由度3〜7のt分布の理由は、ファットテールで裾がそこそこ厚く、ロバスト性を担保できるためです。(自由度の補足として、自由度1のt分布、つまりcauchy分布だと、裾が厚すぎてロバストだが無情報事前分布に近いため推奨されません。また、自由度が8以上のt分布だと、正規分布に近づきショートテールな分布となり、ロバスト性が担保できなくなるため推奨されておりません。ただし、事前分布を正規分布とすると、L1ノルム(Ridge)と同じ働きをするため、使い分けすることが大事です。)