NTTドコモR&Dの技術ブログです。

JAXoptで実践するDecision-focused Learning 〜機械学習×数理最適化の新たなアプローチ〜

はじめまして.NTTドコモ サービスイノベーション部の髙橋優輝です.普段の業務ではデータサイエンスや AI 等を活用した業務効率化や意思決定支援に携わっております.本記事では,機械学習と数理最適化の融合の一つの形である Decision-focused Learning について解説します.

Prediction-focused Learning (PFL) と Decision-focused Learning (DFL) の概要

1. はじめに:「機械学習の予測精度が良い」ことは「良い意思決定」につながるのか?

機械学習と数理最適化の融合は,データ駆動型の意思決定において極めて重要なテーマです.現実社会における多くの最適化問題では,目的関数や制約条件に含まれる係数(パラメータ)が未知であり,これらを過去のデータや特徴量から推定する必要があります.具体的には,以下のようなユースケースが挙げられます.

応用例 機械学習による予測対象 数理最適化の目的
在庫管理 将来の商品需要 在庫保管コストと機会損失の最小化
配送計画 道路の混雑状況・移動時間 総配送時間の最小化
ポートフォリオ 各資産の期待収益率 リスクを抑えた収益の最大化
発電計画 電力需要・再エネ発電量 発電コストの最小化

こうした問題に対する標準的なアプローチは,まず未知のパラメータを機械学習モデルで予測し,その予測値をソルバーに入力して解を得るという Prediction-focused Learning (PFL)1 の枠組みです.

もちろん,訓練を通じてあらゆる入力に対して予測誤差が厳密に $0$ となる完璧な機械学習モデルを構築できるのであれば,このアプローチは最適となります.しかし,現実問題としてそのようなモデルは存在しません.ここで一つの重要な問いが生まれます.

「機械学習の予測精度が良い」ことは「良い意思決定」につながるのか?

予測モデルの誤差が最終的な意思決定に与える影響は一様であるとは限りません.最適化問題の構造によっては,ある方向への予測のズレが決定変数(アクション)に致命的な悪影響を及ぼす一方で,別の方向へのズレは許容されるといった非対称性が存在することがあります.

そのため,単に機械学習で予測値と真値の距離を縮めるだけでは,必ずしもビジネス上のゴール(目的関数の最適化)に直結しない場合があります.そこで,予測自体の誤差ではなく,最終的な意思決定の質(コスト)を最小化するようにモデルを訓練させる Decision-Focused Learning (DFL) というアプローチが注目されています.DFLのサーベイ論文としてMandi et al. (2024)があります.

本記事では,微分可能な最適化ライブラリ JAXopt を活用してこの DFL を実装し,従来手法 (PFL) との挙動の違いを数値実験を通じて確認します.

2. 問題設定

本記事では,ある特徴量に基づいて最適化問題のパラメータを予測し,その結果を用いて意思決定を行う文脈付き最適化 (Contextual Optimization) の枠組みを考えます.

具体的には,以下の流れで意思決定が行われます.

  1. 特徴量の観測: 意思決定者は,特徴量 $z$ を観測します.
  2. パラメータの予測: 特徴量 $z$ から,最適化問題の目的関数に関わる未知のパラメータ $c$ を予測します.ここでは,モデルパラメータ $\theta$ を持つ予測モデルを $M_\theta$ とし,予測値を $\hat{c} \mathrel{:=} M_\theta(z)$ とします.
  3. 意思決定(最適化): 予測されたパラメータ $\hat{c}$ を真の値であると仮定して,次の最適化問題を解き,最適解 $x^{\ast}(\hat{c})$ を決定します. $$ x^{\ast}(\hat{c}) \mathrel{:=} \underset{x \in \mathcal{F}}{\text{argmin}} \ f(x, \hat{c}). $$ ここで,$\mathcal{F}$ は実行可能領域であり,$f$ は目的関数です.

最終的な目的は,予測に基づいて導かれた解 $x^{\ast}(\hat{c})$ が,真のパラメータ $c$ の下でも良い性能を発揮する(低いコストを実現する)ことです.つまり,次式で定義される Regret を最小化することが目的となります. $$ R(x^{\ast}(\hat{c}), c) \mathrel{:=} f(x^{\ast}(\hat{c}), c) - f(x^{\ast}(c), c). $$ ここで,第二項 $f(x^{\ast}(c), c)$ は,もし事前に真のパラメータ $c$ を知っていたら達成できたはずの最小コスト(理論値)です.この値は真のパラメータ $c$ のみに依存し,予測モデルのパラメータ $\theta$ には依存しないため,訓練時の勾配計算においては定数として無視できます.したがって,訓練における実質的な目的は,第一項の $f(x^{\ast}(\hat{c}), c)$ を最小化するようなモデルパラメータ $\theta$ を見つけることに帰着します.なお,定義より,Regret は常に非負となります.

ビジネス文脈においては,$z, c, f, x^{\ast}$にはそれぞれ以下のように対応します.

応用例 特徴量 $z$ パラメータ $c$ 目的関数 $f(x, c)$ 決定変数 $x^{\ast}$
在庫管理 過去の売上、天気 将来の需要量 在庫・欠品コスト 商品の発注数
配送計画 交通量、工事情報 区間の所要時間 総配送時間 配送ルート
ポートフォリオ 過去の株価、金利 期待収益率 リスク・損失 資産の投資配分
広告配信 ユーザー属性 クリック率 (CTR) 機会損失 配信する広告
発電計画 気象予報 再エネ発電量 燃料コスト 発電スケジュール

3. アプローチ:PFL と DFL

前章の問題設定に対し,機械学習モデル $M_\theta$ をどのように訓練するかによって,主に2つのアプローチが存在します.

3.1 Prediction-focused Learning (PFL)

Prediction-focused Learning (PFL) は,最も一般的で直観的な手法です.この手法では,下流にある最適化問題の構造は一旦無視し,とにかく予測値 $\hat{c}$ を真のパラメータ $c$ に近づけること を目的とします.

具体的には,訓練データを $\left\{(z_{i}, c_{i})\right\}_{i = 1}^N$ として,平均二乗誤差(MSE)などの損失関数を用いてモデルを訓練します.例えば,損失関数として MSE を採用した場合,以下の最適化問題を解くことになります. $$ \underset{\theta}{\text{minimize}}\qquad L_{\text{PFL}}(\theta) \mathrel{:=} \frac{1}{N}\sum_{i=1}^N \|M_\theta(z_i) - c_i\|^2. $$

このアプローチの最大の利点は,モデルの訓練のプロセスが最適化ソルバーから切り離されているため,訓練が容易かつ高速である点です.一般的な回帰問題として扱えるため,既存の機械学習ライブラリや知見をそのまま適用できます.しかし,1章で述べた通り,予測誤差の最小化が必ずしもコストの最小化を意味しないタスクにおいては,最適なアプローチとは言えません

3.2 Decision-focused Learning (DFL)

PFLの課題を解決するために提案されたのが Decision-focused Learning (DFL) です.DFLは,予測が多少外れていても,結果的にビジネス上の損失が少ない意思決定ができれば良い という思想に基づいており,最適化問題の構造(コストの非対称性や制約条件)を訓練プロセスに組み込み,最終的な目的関数値を直接減らすようにモデルを訓練します.

例えば,ある製品の需要を予測して在庫量を決める問題を考えます.真の需要が「100個」であるのに対し,機械学習モデルAが「110個(過剰)」,モデルBが「90個(不足)」と予測したとします.MSE の観点では,両者の誤差(絶対値10)は等価です.しかし,ビジネスの現場では,「在庫切れによる顧客の信頼失墜や機会損失」のダメージの方が,「売れ残りの保管コスト」よりも遥かに大きいことがよくあります.この場合,あえて少し多めに予測するモデルの方が,ビジネス上の損失 (Regret) は小さくなります2

PFLではプラスとマイナスの誤差を等しく罰するため,こうした構造を学習できません.一方,DFL では最終的なコスト $f(x^*(\hat{c}), c)$ を見るため,自然と損失の大きい方向へのミスを避けるようにパラメータが調整されます.

DFLでは,Regret の第一項を損失関数として直接扱います.すなわち,モデルの訓練時に次の二層構造の最適化問題を解きます. $$ \begin{aligned} & \underset{\theta}{\text{minimize}} && L_{\text{DFL}}(\theta) \mathrel{:=} \frac{1}{N}\sum_{i=1}^N f(x^{\ast}(\hat{c}_i), c_i), \\ & \text{subject to} && \hat{c}_i \mathrel{:=} M_\theta(z_i), \\ & && x^{\ast}(\hat{c}_i) \mathrel{:=} \underset{x \in \mathcal{F}}{\text{argmin}} \ f(x, \hat{c}_i ). \end{aligned} $$

この問題を勾配法で解くためには,最終的なコスト $L_{\text{DFL}}$ のモデルパラメータ $\theta$ に対する微分を計算する必要があります.連鎖律を用いると,微分は以下のように分解できます. $$ \frac{d L_{\text{DFL}}}{d \theta} = \underbrace{\frac{d f(x^\ast, c)}{d x^\ast}}_{\text{(A)}} \cdot \underbrace{\frac{d x^\ast(\hat{c})}{d \hat{c}}}_{\text{(B)}} \cdot \underbrace{\frac{d \hat{c}}{d \theta}}_{\text{(C)}}. $$

  • (A) 目的関数の微分: 最適化問題の目的関数 $f$ の $x$ に関する微分です.通常,$f$ は既知の関数であり計算可能です.
  • (C) 予測モデルの微分: ニューラルネットワーク等の予測モデルの微分です.PyTorchJAX の自動微分機能で容易に計算可能です.
  • (B) 最適解の予測値に対する微分: ここが DFL の核心であり,最大の難所です.最適化問題の入力 $\hat{c}$ が変化したときに,最適解 $x^{\ast}(\hat{c})$ がどう変化するかを表します.

一般には,写像 $\hat{c} \mapsto x^{\ast}(\hat{c})$ は閉じた形で表現できず,微分計算が困難です.また,例えば,線形計画問題について,この写像は(連続な)区分的定数関数となります.この場合,いくつかの点で微分が定義されず,ほとんど至るところで微分が0となり,勾配情報が逆伝播しません.

しかし,問題のクラスを制限したり,適切な平滑化を行ったりすれば,(B) は計算可能です.例えば,目的関数 $f$ が滑らかな強凸関数であるような制約なし最適化問題であれば,最適性の一次の条件から,次式が成り立ちます. $$ f_x(x^{\ast}(\hat{c}), \hat{c}) = 0. $$ この両辺を $\hat{c}$ で微分すると,次式が得られ,(B) が解析的に求まります. $$ \frac{d x^{\ast}(\hat{c})}{d \hat{c}} = - (f_{xx}(x^{\ast}(\hat{c}), \hat{c}))^{-1} f_{x\hat{c}}(x^{\ast}(\hat{c}), \hat{c}). $$ この考え方は,制約あり最適化問題(制約条件を目的関数に取り込む,すなわち,緩和問題を考える)や線形計画問題(正則化項を加えて滑らかにする)にも拡張可能です.前述のサーベイ論文では,これ以外にも,制約付き最適化問題の最適性の条件について陰関数微分 (Implicit Differentiation) を用いる手法反復法の反復を展開 (Unroll) して自動微分を活用する方法が挙げられています.

また,上記サーベイ論文には,微分情報を使わない DFL の解法について言及されています.

4. JAXopt の特徴

前章で確認した通り,DFL を実現するためには,最適化問題を解く工程そのものを微分可能にする必要があります.理論上は,陰関数微分などを用いればよいですが,これをゼロから実装し,高速に動作させるのは容易ではありません.

そこで本記事では,DFL を実践するために JAXopt という Python ライブラリを採用しました.本章では,JAXopt の特徴について解説します.

公式ドキュメントを見ると特徴として3つ挙げられています.

1. ハードウェアアクセラレーション (GPU / TPU)

JAXopt は JAX 上に構築されているため,最適化ソルバーの反復計算そのものを GPU や TPU 上で実行可能です.DFL の学習ループでは,何千・何万回と最適化問題を解く必要があるため,CPU ベースの従来のソルバー(scipy.optimize や商用ソルバー等)と比較して,圧倒的な計算速度の差が生まれます.

2. 自動バッチ処理 (vmap)

JAX の vmap 機能に対応しており,複数の最適化問題を並列に解く(バッチ処理する)ことが容易です.これにより,ミニバッチごとの学習において GPU の並列演算性能を最大限に引き出すことができます.

3. 微分可能

通常の最適化ソルバーは最適解 $x^*$ を出力して終わりですが,JAXopt はその計算過程を通じて勾配を逆伝播させることができます.JAXopt では,最適化ソルバーをニューラルネットワークの一部(微分可能な層) として扱うことができます.実装上は,ソルバーの引数 implicit_diff を制御することで,以下のような微分計算手法を選択できます. * implicit_diff=True: 陰関数定理に基づき,最適性の条件のみを用いて勾配を計算します(メモリ効率が良い). * implicit_diff=False: 最適化ソルバーの反復計算を展開し,計算グラフを遡って勾配を計算します.

JAXopt は非常に多機能で,制約なし/あり最適化,二次計画,非平滑最適化,求根,不動点探索など,多岐にわたる問題設定に対応したソルバーを提供しており,DFL を実践する上で非常に強力な武器となります.

5. 数値実験

本章では,JAXopt を用いた DFL が実際にどのように機能するかを,具体的な数値実験を通して確認します.比較対象として,PFL を用います.

5.1 実験設定

数値実験においては,同一のデータ・機械学習モデルを利用し,PFL と DFL のそれぞれについて,文脈付き二次計画問題を解きました.

  • データ生成: 入力特徴ベクトル $\boldsymbol{z}_i \in \mathbb{R}^5$ から,最適化問題のパラメータ(目的関数の線形項の係数ベクトル)$\boldsymbol{c}_i \in \mathbb{R}^2$ を非線形な関係で生成しました.このペアを1000個用意し,訓練データ,テストデータをそれぞれ800個,200個に分割しました.
  • 予測モデル: PFL と DFL について共通の構造を持つ多層パーセプトロン (MLP) を用いました.
  • 評価指標: テストデータに対する MSE と Regret の値を報告します.
    1. MSE: 機械学習による予測精度 $$ \text{MSE}_{\text{overall}} \mathrel{:=} \frac{1}{N} \sum_{i = 1}^{N} | \hat{\boldsymbol{c}}_i - \boldsymbol{c}_i |^2 = \frac{1}{N} \sum_{i = 1}^{N}(\hat{c}_{i, 0} - c_{i, 0})^{2} + \frac{1}{N} \sum_{i = 1}^{N}(\hat{c}_{i, 1} - c_{i, 1})^{2} \mathrel{=:} \text{MSE}_{c0} + \text{MSE}_{c1}. $$
    2. Regret: 意思決定の質(真のパラメータの下での,予測解と真の最適解のコスト差の平均) $$ \text{Regret} \mathrel{:=} \frac{1}{N} \sum_{i=1}^{N} \left( f(\boldsymbol{x}^{\ast}(\hat{\boldsymbol{c}}_i), \boldsymbol{c}_i) - f(\boldsymbol{x}^{\ast}(\boldsymbol{c}_i), \boldsymbol{c}_i) \right). $$ ここで,決定変数ベクトルは $\boldsymbol{x} \in \mathbb{R}^2$ です.
  • 訓練: PFL では訓練データに対する MSE を,DFLでは訓練データに対する Regret を最小化するように訓練を行いました.

5.2 Case 1: 制約なし二次計画問題

Case 1では下流の最適化問題を次式で定義しました. $$ \underset{\boldsymbol{x}}{\text{minimize}} \quad \frac{1}{2} \boldsymbol{x}^\top \boldsymbol{Q} \boldsymbol{x} + \boldsymbol{c}^\top \boldsymbol{x}. $$ ここで,$\boldsymbol{c}$ は機械学習により予測されるパラメータです.また,数値実験では,行列 $\boldsymbol{Q}$ を次のように設定しました. $$ \boldsymbol{Q} = \begin{pmatrix} 20.0 & 1.0 \\ 1.0 & 0.2 \end{pmatrix}. $$ ここで,意図的に$Q_{00}$を$Q_{11}$より非常に大きな値に設定しました.この設定下では,良い意思決定を行うためには $c_0$ よりも $c_1$ の予測精度が支配的な要因となります.

実験結果

数値実験の結果は次のようになりました.

Model MSE (overall) MSE (c0) MSE (c1) Regret
PFL 49.75 22.82 26.94 86.55
DFL 426.18 417.82 8.36 34.50

最適化問題の係数パラメータ $\boldsymbol{c}$ の予測精度は PFL の方が良いですが,Regret の値は DFLの方が良くなっています.

また,テストデータについて実際に得られた最適解をプロットしたところ,図1に示す結果が得られました.これを見ると,DFL が PFL よりも真の最適解に近い解が得られるように訓練できていることが読み取れます.

図1: 制約なし最適化問題に対する真の最適解(黒い星),PFL で得られた最適解(赤い三角),DFL で得られた最適解(青い丸)のプロット.真の最適解と各手法の最適解の間に線を引いています.

考察

上記実験結果から,PFL では MSE を最小化するように機械学習モデルを訓練するため,$c_0, c_1$ 共にバランスよく誤差を減らしています.一方,DFL では Regret を最小化するように機械学習モデルを訓練しており,感度の高い $c_1$ の予測精度が高くなっています.この結果から,良い予測よりも良い意思決定を優先するという DFL の性質が見て取れます.

5.3 Case 2: 制約あり二次計画問題

Case 2では,Case 1の問題設定に線形不等式制約を加えた問題を考えます. $$ \begin{aligned} & \underset{\boldsymbol{x}}{\text{minimize}} && \frac{1}{2} \boldsymbol{x}^\top \boldsymbol{Q} \boldsymbol{x} + \boldsymbol{c}^\top \boldsymbol{x}, \\ & \text{subject to} && \boldsymbol{G} \boldsymbol{x} \le \boldsymbol{h}. \end{aligned} $$

数値実験では,行列 $\boldsymbol{Q}, \boldsymbol{G}$ とベクトル $\boldsymbol{h}$ を次のように設定しました. $$ \boldsymbol{Q} = \begin{pmatrix} 20.0 & 1.0 \\ 1.0 & 0.2 \end{pmatrix}, \boldsymbol{G} = \begin{pmatrix} 0.0 & -1.0 \\ 2.0 & 3.0 \end{pmatrix}, \quad \boldsymbol{h} = \begin{pmatrix} 0.0 \\ 1000.0 \end{pmatrix}. $$

実験結果

数値実験の結果は次のようになりました.

Model MSE (overall) MSE (c0) MSE (c1) Regret
PFL 49.75 22.82 26.94 15.70
DFL 4130.17 629.40 3500.78 3.07

制約ありの場合,DFL の挙動はさらに極端となっています.DFL の MSE はパラメータの予測精度としては PFL に大きく劣りますが,Regret では PFL を圧倒しています.

また,テストデータについて実際に得られた最適解をプロットしたところ,図2に示す結果が得られました.これを見ると,Case 1と同様に,DFL で求めた解の方が真の最適解に近くなっていることが読み取れます. さらに,PFL と DFL の両方について,制約 $(x_1 \geq 0, 2x_0 + 3x_1 \leq 1000)$ が守られていることも読み取れます.

図2: 制約あり最適化問題に対する真の最適解(黒い星),PFL で得られた最適解(赤い三角),DFL で得られた最適解(青い丸)のプロット.真の最適解と各手法の最適解の間に線を引いています.

考察

上記実験結果から,制約条件があるときについても,良い予測よりも良い意思決定を優先するという DFL の性質が見て取れます.

また,図2を見ると DFLの解は真の解と同様に制約の境界線上に強く吸着していることが分かります.今回のように制約が最適解を決定づける場合,中途半端な予測をして制約が有効にならないよりも,あえて制約が有効になりやすいように極端なパラメータ予測を行う方が,結果として Regret を小さくできます.つまり,DFL における非常に大きな予測誤差は訓練の失敗ではなく,解を安定して制約の境界上に留めるようにモデルが訓練されたと解釈でき,制約を意識した訓練が行われていると推測されます.

一方,図2を見ると制約の境界線上から離れた領域についても,Case 1と同様に DFL は PFL より真の解に近い解が得られており,制約を有効にするかどうかを柔軟に訓練できていると推測されます.

5.4 数値実験に使用したプログラム

プログラムの全文を掲載する前に、DFL の実装において特に重要な微分可能な最適化ソルバー損失関数の定義について解説します.

微分可能な最適化ソルバーの実装

まず,二次計画問題を表す抽象クラスとして,共通インターフェース QP を定義しました.このクラスは,目的関数値をバッチで計算する objective_value と,最適化問題をバッチで解く solve メソッドを持ちます.

class QP(Protocol):
    """二次計画問題の共通インターフェース"""

    def objective_value(self, x: jax.Array, c: jax.Array) -> jax.Array:
        """目的関数値を計算"""
        ...

    def solve(self, c: jax.Array) -> jax.Array:
        """最適解を計算"""
        ...

Case 1では UnconstrainedQP クラスを用いました.制約なし二次計画問題の最適解は閉じた形で求まるため,反復的なソルバーは不要です.ここでは行列計算のみで実装しています.バッチ処理を利用するため,vmap でラップしています.

class UnconstrainedQP(QP):
    """制約なし二次計画問題"""

    def __init__(self) -> None:
        """問題の初期化"""
        self.Q = jnp.array([[20.0, 1.0], [1.0, 0.2]])
        self.Q_inv = jnp.linalg.inv(self.Q)

    def objective_value(self, x: jax.Array, c: jax.Array) -> jax.Array:
        """目的関数の値を計算"""

        def _objective_value(x_i: jax.Array, c_i: jax.Array) -> jax.Array:
            quad = 0.5 * x_i.T @ self.Q @ x_i
            linear = c_i.T @ x_i
            return quad + linear

        return vmap(_objective_value)(x, c)

    @property
    def solve(self) -> Callable[[jax.Array], jax.Array]:
        """最適解を計算"""

        def _solve(c: jax.Array) -> jax.Array:
            return -self.Q_inv @ c

        return jit(vmap(_solve))

Case 2では ConstrainedQP クラスを用いました.不等式制約が含まれるため解析的には解けず,数値最適化ソルバーが必要になります.JAXopt では二次計画ソルバーとして5つ実装されています.その内,Case 2の不等式制約付きの二次計画問題については,jaxopt.CvxpyQPjaxopt.OSQP を活用可能です.今回の実装では,より高速な jaxopt.OSQP を用いました.

以下の実装の通り,solve メソッド内でソルバーの run 関数を呼び出すだけで実装が完了します.一見するとフォワードパス(最適化の実行)しか記述していないように見えますが,JAXopt は内部でバックワードパス(勾配計算)を定義しています.

class ConstrainedQP(QP):
    """制約あり二次計画問題"""

    def __init__(self) -> None:
        """問題の初期化"""
        self.Q = jnp.array([[20.0, 1.0], [1.0, 0.2]])

        self.G = jnp.array(
            [
                [0.0, -1.0],
                [2.0, 3.0],
            ],
        )
        self.h = jnp.array([0.0, 1000.0])

        # 微分可能なQPソルバーの初期化
        self.osqp = jaxopt.OSQP()

    def objective_value(self, x: jax.Array, c: jax.Array) -> jax.Array:
        """目的関数の値を計算"""

        def _objective_value(x_i: jax.Array, c_i: jax.Array) -> jax.Array:
            quad = 0.5 * x_i.T @ self.Q @ x_i
            linear = c_i.T @ x_i
            return quad + linear

        return vmap(_objective_value)(x, c)

    @property
    def solve(self) -> Callable[[jax.Array], jax.Array]:
        """最適解を計算"""

        def _solve(c_sample: jax.Array) -> jax.Array:
            sol = self.osqp.run(
                params_obj=(self.Q, c_sample),
                params_ineq=(self.G, self.h),
            )
            return sol.params.primal

        # vmapでバッチ処理化し、jitコンパイル
        return jit(vmap(_solve))

損失関数の実装

PFL と DFL の最大の違いは損失関数の定義にあります.

  • PFL: 予測パラメータ c_pred と真のパラメータ c_batch の MSE を最小化します.
  • DFL: 予測パラメータ c_pred に基づいて導かれた最適解 x_star が,真のパラメータ c_batch の下で生むコスト (Regret の第一項) を最小化します.
@partial(jit, static_argnames=["apply_fn"])
def loss_fn_pfl(
    params: Any,
    apply_fn: Callable,
    z_batch: jax.Array,
    c_batch: jax.Array,
) -> jax.Array:
    """PFL Loss (MSE)

    Args:
        params (Any): モデルのパラメータ
        apply_fn (Callable): モデルの適用関数
        z_batch (jax.Array): 入力(Batch, 5)
        c_batch (jax.Array): 係数パラメータ (Batch, 2)

    Returns:
        jax.Array: ロス
    """
    c_pred = apply_fn({"params": params}, z_batch)
    return jnp.mean(jnp.sum((c_pred - c_batch) ** 2, axis=-1))


@partial(jit, static_argnames=["apply_fn", "qp_problem"])
def loss_fn_dfl(
    params: Any,
    apply_fn: Callable,
    z_batch: jax.Array,
    c_batch: jax.Array,
    qp_problem: QP,
) -> jax.Array:
    """DFL Loss (Task Loss)

    Args:
        params (Any): モデルのパラメータ
        apply_fn (Callable): モデルの適用関数
        z_batch (jax.Array): 入力(Batch, 5)
        c_batch (jax.Array): 係数パラメータ (Batch, 2)
        qp_problem (QP): 問題インスタンス
    Returns:
        jax.Array: ロス
    """
    c_pred = apply_fn({"params": params}, z_batch)
    x_star = qp_problem.solve(c_pred)
    return jnp.mean(qp_problem.objective_value(x_star, c_batch))

全体の実装

本実験で使用した環境設定ファイル (pyproject.toml) と Python コード (main.py) です.クリックして展開してください.

Google Colabで実行する場合は,以下のコマンド実行して,main.py を実行してください.

!pip install jaxopt

pyproject.toml

[project]
requires-python = ">=3.9"
dependencies = [
    "flax>=0.8.5",
    "jax>=0.4.30",
    "jaxopt>=0.8.5",
    "matplotlib>=3.9.4",
    "mypy>=1.19.0",
    "optax>=0.2.4",
]

main.py

"""Prediction-focused Learning (PFL) と Decision-focused Learning (DFL) の比較実験"""

from collections.abc import Callable
from functools import partial
from typing import Any, Protocol

import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxopt
import matplotlib.pyplot as plt
import optax
from flax.training import train_state
from jax import device_get, jit, vmap


class QP(Protocol):
    """二次計画問題の共通インターフェース"""

    def objective_value(self, x: jax.Array, c: jax.Array) -> jax.Array:
        """目的関数値を計算"""
        ...

    def solve(self, c: jax.Array) -> jax.Array:
        """最適解を計算"""
        ...


class UnconstrainedQP(QP):
    """制約なし二次計画問題"""

    def __init__(self) -> None:
        """問題の初期化"""
        self.Q = jnp.array([[20.0, 1.0], [1.0, 0.2]])
        self.Q_inv = jnp.linalg.inv(self.Q)

    def objective_value(self, x: jax.Array, c: jax.Array) -> jax.Array:
        """目的関数の値を計算"""

        def _objective_value(x_i: jax.Array, c_i: jax.Array) -> jax.Array:
            quad = 0.5 * x_i.T @ self.Q @ x_i
            linear = c_i.T @ x_i
            return quad + linear

        return vmap(_objective_value)(x, c)

    @property
    def solve(self) -> Callable[[jax.Array], jax.Array]:
        """最適解を計算"""

        def _solve(c: jax.Array) -> jax.Array:
            return -self.Q_inv @ c

        return jit(vmap(_solve))


class ConstrainedQP(QP):
    """制約あり二次計画問題"""

    def __init__(self) -> None:
        """問題の初期化"""
        self.Q = jnp.array([[20.0, 1.0], [1.0, 0.2]])

        self.G = jnp.array(
            [
                [0.0, -1.0],
                [2.0, 3.0],
            ],
        )
        self.h = jnp.array([0.0, 1000.0])

        # 微分可能なQPソルバーの初期化
        self.osqp = jaxopt.OSQP()

    def objective_value(self, x: jax.Array, c: jax.Array) -> jax.Array:
        """目的関数の値を計算"""

        def _objective_value(x_i: jax.Array, c_i: jax.Array) -> jax.Array:
            quad = 0.5 * x_i.T @ self.Q @ x_i
            linear = c_i.T @ x_i
            return quad + linear

        return vmap(_objective_value)(x, c)

    @property
    def solve(self) -> Callable[[jax.Array], jax.Array]:
        """最適解を計算"""

        def _solve(c_sample: jax.Array) -> jax.Array:
            sol = self.osqp.run(
                params_obj=(self.Q, c_sample),
                params_ineq=(self.G, self.h),
            )
            return sol.params.primal

        # vmapでバッチ処理化し、jitコンパイル
        return jit(vmap(_solve))


def generate_data(key: jax.Array, n_samples: int) -> tuple[jax.Array, jax.Array]:
    """データ生成

    Args:
        key (jax.Array): 乱数キー
        n_samples (int): サンプル数

    Returns:
        tuple[jax.Array, jax.Array]: z: 特徴量(n_samples, 5), c: 係数パラメータ(n_samples, 2)
    """
    k1, k2, k3, k4 = jax.random.split(key, 4)
    # 入力
    z = jax.random.normal(k1, (n_samples, 5))
    # 非線形変換 + ノイズ
    W = jax.random.normal(k2, (5, 2))
    b = jax.random.normal(k3, (2,))
    c_clean = jnp.tanh(z @ W + b) * 100
    noise = jax.random.normal(k4, (n_samples, 2))
    c = c_clean + noise
    return z, c


class MLP(nn.Module):
    """PFL, DFLに共通の予測モデル"""

    @nn.compact
    def __call__(self, z: jax.Array) -> jax.Array:
        """Forward pass"""
        z = nn.Dense(32)(z)
        z = nn.relu(z)
        return nn.Dense(2)(z)


def create_train_state(
    key: jax.Array,
    input_shape: tuple[int, ...],
    learning_rate: float,
) -> train_state.TrainState:
    """TrainStateの初期化"""
    model = MLP()
    params = model.init(key, jnp.ones(input_shape))["params"]
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


@partial(jit, static_argnames=["apply_fn"])
def loss_fn_pfl(
    params: Any,
    apply_fn: Callable,
    z_batch: jax.Array,
    c_batch: jax.Array,
) -> jax.Array:
    """PFL Loss (MSE)

    Args:
        params (Any): モデルのパラメータ
        apply_fn (Callable): モデルの適用関数
        z_batch (jax.Array): 入力(Batch, 5)
        c_batch (jax.Array): 係数パラメータ (Batch, 2)

    Returns:
        jax.Array: ロス
    """
    c_pred = apply_fn({"params": params}, z_batch)
    return jnp.mean(jnp.sum((c_pred - c_batch) ** 2, axis=-1))


@partial(jit, static_argnames=["apply_fn", "qp_problem"])
def loss_fn_dfl(
    params: Any,
    apply_fn: Callable,
    z_batch: jax.Array,
    c_batch: jax.Array,
    qp_problem: QP,
) -> jax.Array:
    """DFL Loss (Task Loss)

    Args:
        params (Any): モデルのパラメータ
        apply_fn (Callable): モデルの適用関数
        z_batch (jax.Array): 入力(Batch, 5)
        c_batch (jax.Array): 係数パラメータ (Batch, 2)
        qp_problem (QP): 問題インスタンス
    Returns:
        jax.Array: ロス
    """
    c_pred = apply_fn({"params": params}, z_batch)
    x_star = qp_problem.solve(c_pred)
    return jnp.mean(qp_problem.objective_value(x_star, c_batch))


def evaluate(
    state: train_state.TrainState,
    z_test: jax.Array,
    c_test: jax.Array,
    qp: QP,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array], jax.Array]:
    """評価

    Args:
        state (train_state.TrainState): モデルの状態
        z_test (jax.Array): テストデータの特徴量
        c_test (jax.Array): 真の係数パラメータ
        qp (QP): 問題インスタンス

    Returns:
        tuple[jax.Array, tuple[jax.Array, jax.Array], jax.Array]: 評価結果
    """
    c_pred = state.apply_fn({"params": state.params}, z_test)
    # MSEの意味での精度
    mse = jnp.mean(jnp.sum((c_pred - c_test) ** 2, axis=-1))
    mse_c1 = jnp.mean((c_pred[:, 0] - c_test[:, 0]) ** 2)
    mse_c2 = jnp.mean((c_pred[:, 1] - c_test[:, 1]) ** 2)
    # Regretの意味での精度
    x_star_pred = qp.solve(c_pred)
    x_star_true = qp.solve(c_test)
    objective_pred = qp.objective_value(x_star_pred, c_test)
    objective_true = qp.objective_value(x_star_true, c_test)
    regret = jnp.mean(objective_pred - objective_true)
    return mse, (mse_c1, mse_c2), regret


def visualize_solutions(
    state_pfl: train_state.TrainState,
    state_dfl: train_state.TrainState,
    z_test: jax.Array,
    c_test: jax.Array,
    qp: QP,
    filename: str,
) -> None:
    """最適解の可視化

    Args:
        state_pfl (train_state.TrainState): PFLモデルの状態
        state_dfl (train_state.TrainState): DFLモデルの状態
        z_test (jax.Array): テストデータの特徴量
        c_test (jax.Array): テストデータのターゲットパラメータ
        qp (QP): 問題インスタンス
        rs (np.random.RandomState): 乱数生成器
        filename (str): 保存ファイル名
    """
    # テストデータに対する予測と対応する最適解
    c_pred_pfl = state_pfl.apply_fn({"params": state_pfl.params}, z_test)
    c_pred_dfl = state_dfl.apply_fn({"params": state_dfl.params}, z_test)

    x_star_true = qp.solve(c_test)
    x_star_pfl = qp.solve(c_pred_pfl)
    x_star_dfl = qp.solve(c_pred_dfl)

    x_t = device_get(x_star_true)
    x_pfl = device_get(x_star_pfl)
    x_dfl = device_get(x_star_dfl)

    plt.figure(figsize=(6, 6))
    plt.scatter(x_t[:, 0], x_t[:, 1], s=20, c="k", marker="*", label="True x*", alpha=0.7)
    plt.scatter(x_pfl[:, 0], x_pfl[:, 1], s=12, c="red", marker="^", label="PFL x*", alpha=0.7)
    plt.scatter(x_dfl[:, 0], x_dfl[:, 1], s=12, c="blue", marker="o", label="DFL x*", alpha=0.7)

    for i in range(x_t.shape[0]):
        plt.plot([x_pfl[i, 0], x_t[i, 0]], [x_pfl[i, 1], x_t[i, 1]], color="red", alpha=0.08)
        plt.plot([x_dfl[i, 0], x_t[i, 0]], [x_dfl[i, 1], x_t[i, 1]], color="blue", alpha=0.08)

    plt.xlabel("x0")
    plt.ylabel("x1")
    plt.title("Optimal Solutions: True vs PFL vs DFL (Test Set)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    print(f"Saved visualization to {filename}")


def run_experiment(
    qp: QP,
    key_init: jax.Array,
    z_train: jax.Array,
    c_train: jax.Array,
    z_test: jax.Array,
    c_test: jax.Array,
    epochs: int,
    filename: str,
) -> None:
    """数値実験の実施

    Args:
        qp (QP): 二次計画問題
        key_init (jax.Array): モデル初期化用の乱数キー
        z_train (jax.Array): 訓練データの特徴量
        c_train (jax.Array): 訓練データの係数パラメータ
        z_test (jax.Array): テストデータの特徴量
        c_test (jax.Array): テストデータの真の係数パラメータ
        epochs (int): エポック数
        filename (str): 保存ファイル名
    """
    # モデル初期化
    state_pfl = create_train_state(key_init, (1, 5), LR)
    state_dfl = create_train_state(key_init, (1, 5), LR)

    # gradの定義
    grad_loss_fn_pfl = jax.value_and_grad(loss_fn_pfl)
    grad_loss_fn_dfl = jax.value_and_grad(
        lambda p, af, z, c: loss_fn_dfl(p, af, z, c, qp),
    )

    # 学習ループ
    for _ in range(epochs):
        # PFL Update
        _, grads_pfl = grad_loss_fn_pfl(
            state_pfl.params,
            state_pfl.apply_fn,
            z_train,
            c_train,
        )
        state_pfl = state_pfl.apply_gradients(grads=grads_pfl)

        # DFL Update
        _, grads_dfl = grad_loss_fn_dfl(
            state_dfl.params,
            state_dfl.apply_fn,
            z_train,
            c_train,
        )
        state_dfl = state_dfl.apply_gradients(grads=grads_dfl)

    # 評価
    mse_pfl, mse_pfl_components, regret_pfl = evaluate(
        state_pfl,
        z_test,
        c_test,
        qp,
    )
    mse_dfl, mse_dfl_components, regret_dfl = evaluate(
        state_dfl,
        z_test,
        c_test,
        qp,
    )

    # 結果をテーブル形式で表示
    header = f"| {'Model':^7} | {'MSE (overall)':^15} | {'MSE (c0)':^12} | {'MSE (c1)':^12} | {'Regret':^12} |"
    separator = "+" + "-" * 9 + "+" + "-" * 17 + "+" + "-" * 14 + "+" + "-" * 14 + "+" + "-" * 14 + "+"

    print(separator)
    print(header)
    print(separator)

    row_pfl = (
        f"| {'PFL':^7} | {mse_pfl:^15.5f} | {mse_pfl_components[0]:^12.5f} | "
        f"{mse_pfl_components[1]:^12.5f} | {regret_pfl:^12.5f} |"
    )
    print(row_pfl)

    row_dfl = (
        f"| {'DFL':^7} | {mse_dfl:^15.5f} | {mse_dfl_components[0]:^12.5f} | "
        f"{mse_dfl_components[1]:^12.5f} | {regret_dfl:^12.5f} |"
    )
    print(row_dfl)

    print(separator)

    # 可視化
    visualize_solutions(
        state_pfl,
        state_dfl,
        z_test,
        c_test,
        qp,
        filename,
    )


if __name__ == "__main__":
    SEED = 42
    N_SAMPLES = 1000
    TEST_RATIO = 0.2
    EPOCHS = 5000
    LR = 0.005
    key = jax.random.PRNGKey(SEED)
    key_init = jax.random.PRNGKey(SEED + 1)

    # データ準備
    z, c = generate_data(key, N_SAMPLES)
    n_train = int(N_SAMPLES * (1 - TEST_RATIO))
    z_train, z_test = z[:n_train], z[n_train:]
    c_train, c_test = c[:n_train], c[n_train:]

    # 制約なし
    print("\n=== Unconstrained QP Experiment ===")
    run_experiment(
        UnconstrainedQP(),
        key_init,
        z_train,
        c_train,
        z_test,
        c_test,
        EPOCHS,
        "solutions_unconstrained.png",
    )
    # 制約あり
    print("\n=== Constrained QP Experiment ===")
    qp = ConstrainedQP()
    run_experiment(
        ConstrainedQP(),
        key_init,
        z_train,
        c_train,
        z_test,
        c_test,
        EPOCHS,
        "solutions_constrained.png",
    )

6. おわりに

本記事では,JAXopt を用いて Decision-Focused Learning (DFL) を実装し,従来の Prediction-Focused Learning (PFL) と比較を行いました.実験の結果,DFL は以下の特性を持つことが確認できました.

  • 目的関数の感度を利用する: 目的関数値への影響が大きいパラメータを重点的に学習する(Case 1).
  • 最適化問題の構造を利用する: 制約条件を満たす解を得るために,あえてバイアスのかかった予測を行うことがある(Case 2).

実際のビジネス課題では,制約条件が複雑であったり,目的関数が非対称なコスト構造を持っていたりすることが一般的です.そのようなケースにおいて,単に予測精度を追うのではなく,最終的な意思決定に焦点を当てる DFL は非常に強力な選択肢となります.

JAXopt のような微分可能最適化ライブラリの登場により,こうした高度な手法も比較的容易に実装できるようになりました.

この記事を読んでいただき,少しでも皆様の参考となれば幸いです!


  1. Prediction-then-Optimize (PTO) と呼ばれることもあります.
  2. この例のような単純な設定であれば,実際には損失関数を工夫したりデータに重み付けをしたりすることで PFL でも対処可能です.ここでは直感的な理解のために例示しています.DFL の真価は,単純な重み付け等では表現できない,より複雑な構造を持つ問題に対して発揮されます.