NTTドコモR&Dの技術ブログです。

Two-Tower レコメンドをちゃんと理解する ― TFRS / 対照学習 / HNSW まで

ドコモR&D戦略部、川畠雄司です。普段は仮想ユーザーモデルを用いて実ユーザーの行動を予測する「仮想マーケティング」技術の研究や、その社会実装に向けた開発・運用に従事しています。

本記事では「Two-Towerモデル」について解説します。 Google(YouTube)をはじめとするビッグテック企業の推薦システムでも採用されているこのアーキテクチャは、大規模なデータからユーザーの好みを高速かつ高精度にマッチングさせる技術として、近年のデファクトスタンダードになりつつ?あります。今回はその仕組みと実装について紐解いていきます

1. そもそもレコメンドとは

レコメンドでは、「どのユーザーにどのアイテムを出すか」をスコア関数 $ f(u, i; \theta) $で決め、スコアの高い順に上位 $ K $ 件を返します。

古典的な協調フィルタリングは

$$f(u, i) = p_{\, u}^\top q_{\, i}$$

のように、ユーザー埋め込み $ p_{\, u}$ とアイテム埋め込み $ q_{\, i}$ を 過去ログだけから学習するため、新ユーザー/新アイテムには弱いです(コールドスタート)。


2. Two-Tower モデルとは?

Two-Tower(または dual-encoder)は,ユーザ側とアイテム側を別々のネットワークで埋め込みベクトルに変換し,その内積をスコアとして使うモデルです。

  • Query tower

    • ユーザ ID,属性,行動特徴などを入力
    • ベクトル $ u(x_i, \theta_q)\in\mathbb{R}^d $ に変換
  • Candidate tower

    • アイテム ID,カテゴリ,テキスト特徴などを入力
    • ベクトル $ v(y_j, \theta_c)\in\mathbb{R}^d$ に変換

推論時には,あるクエリ $x$ について $$
s(x, y) = u(x)^\top v(y) $$ をスコアとして,大量のアイテム集合から上位 $K$ 個を取り出します。

この構造の良いポイントは

  • コールドスタート問題に強い(特徴量さえ入ればベクトルを出せるため)
  • アイテム側のベクトル $ v(y)$ をあらかじめ計算して保存しておける
  • スコアが単なる「ベクトル内積」なので,近似最近傍探索(ANN) と相性が良い
  • リッチな特徴量を柔軟に入力として組み込める点(マルチモーダル等,ベクトル化できればなんでも可)

ということです。これによって「検索エンジンのようなスピードでレコメンド」が可能になります。

また、このとき各タワーのニューラルネットワークアーキテクチャが何かは自由です。 MLP、CNN、RNN/LSTM、Transformerなど、様々なアーキテクチャをデータの種類や目的に応じて選択できます。

two-tower モデル

3. TFRS(TensorFlow Recommenders)による Two-Tower

3.1 TFRS とは?

TFRS は TensorFlow 上でレコメンドモデルを作るためのライブラリで,

  • モデル定義(Two-Tower / Ranking 等)
  • ロス関数(対照学習,ランキングロス)
  • 評価メトリクス(Recall@K など)
  • 大量候補からの ANN 検索

を一通り提供してくれます。

Two-Tower の場合,TFRS では典型的に次のようなクラス構成になります。

class MovielensModel(tfrs.models.Model):

  def __init__(self, layer_sizes):
    super().__init__()

    # Query / Candidate tower を定義
    self.query_model = QueryModel(layer_sizes)
    self.candidate_model = CandidateModel(layer_sizes)

    # Retrieval タスク(対照学習+評価メトリクス)
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )

  def compute_loss(self, features, training=False):
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(
        query_embeddings, movie_embeddings, compute_metrics=not training)

query_embeddings が $ u(x_{\, i})$ ,movie_embeddings が $ v(y_{\, i})$ に相当し,tfrs.tasks.Retrieval が学習ロスと評価をまとめて担当します。

4. 対照学習と Deep Metric Learning

4.1 バッチ内負例による対照学習

Two-Tower の学習では,「クエリと正しいアイテムは近く,それ以外は遠く」 という制約を課すために,対照学習(contrastive learning) の枠組みがよく使われます。

1 ミニバッチ B に,過去に行動があったペア $ (x_i, y_i)$ を $ i\in B$ 個持ってくるとします。Query tower と Candidate tower で埋め込みを計算すると,

$$ u_{\, i} = u(x_{\, i}), v_{\, i} = v(y_{\, i})$$

となり,これらからロジット行列を作ります。

$$ logits_{\, ij} = u_{\, i}^\top v_{\, j}$$

ここで

  • 対角成分 $(i=j)$: 正例ペア
  • 非対角成分 $(i\neq j)$: バッチ内で勝手に作った「負例」

になります。各行ごとに Softmax を取り,正例に対するクロスエントロピーをロスとすると,

\displaystyle{
L = - \sum_{i \in B} \log \left(
    \frac{\exp(u_i^\top v_i)}
         {\sum_{j \in B} \exp(u_i^\top v_j)}
\right)
}

という InfoNCE 型のロス になります。TFRS の Retrieval タスクのデフォルトはほぼこれと同じです。

バッチサイズを大きくすると「疑似負例」の数が増え、区別すべきアイテムが増えるので性能が上がりやすいという性質があります。

4.2 Deep Metric Learning 的な見方

このロスは

  • 似てほしいペア $(x_i, y_i),(i=j)$ の距離(負の内積)を小さく
  • 似てほしくない$(x_i, y_j), (i\neq j)$の距離を大きく

する,いわゆる deep metric learning の一種と見ることができます。Triplet loss や N-pair loss と同じ系譜で,Two-Tower はその「アンカー=クエリ,ポジティブ=クリックアイテム,ネガティブ=他アイテム」という特殊ケースと考えると理解しやすいです。

5. Negative Sampling と Popularity Bias

5.1 バッチ内負例の歪み

バッチ内負例を使うと,人気アイテムほど負例として登場しやすいという問題が生じます。学習データには人気アイテムが何度も登場する結果, あるミニバッチで人気アイテムは「正例にもなるし,別のクエリから見ると負例にもなる」Softmax ロスでは,負例として出てきたときにスコアを下げる方向の勾配が強く働き,結果として人気アイテムのスコアが過小評価され,ニッチなアイテムのスコアが 過大評価 されるという sampling bias が生じます。

5.2 Sampling Bias の代表的な対処法

このようにバッチ内負例だけに頼ると,人気アイテムのスコアが不当に下がり,ニッチアイテムが過大評価されるという sampling bias が入り込みます。実務では,このバイアスを軽減するために,いくつかの代表的な対処法が使われます。ここでは,Two-Tower / TFRS 文脈でよく登場するものをざっくり整理しておきます。

カテゴリ 手法の例 何をしているか メリット デメリット / 注意点 Two-Tower / TFRS での位置づけ
重要度補正 $log p_j $補正(Sampling-Bias-Corrected Neural Modeling) サンプリング分布 $p_j$ に対する重要度補正として,ロジットを $u_i^\top v_j - \log p_j $に置き換える。TFRS では candidate_sampling_probability を渡すだけで実装可能。 理論的にきれいで,popular 過小評価 / niche 過大評価を直接補正できる。 $p_j$ を推定する前処理が必要。サンプリング分布が変わったら再計算が必要。 TFRS 推奨のやり方
サンプリング分布の調整 popularity-aware sampling / ほぼ一様サンプリング 「人気アイテムほど負例に出やすい」状態を緩和するために,サンプリング確率自体を調整する 学習パイプライン側の工夫だけで導入できる。 真の分布とのズレは残る, log p_j 補正ほどバイアスを打ち消せないことも。 TFRS でも candidates の作り方次第で実現可能。$log p_j $補正と併用されることも多い。
Hard Negative 系 近いアイテムを負例として優先サンプリング(hard negative mining) 類似カテゴリ・ジャンルなど「紛らわしいアイテム」を意図的に負例として選び,識別能力を高める。 モデルが「本当に紛らわしいアイテム」を区別するようになり,ランキング性能向上が期待できる。 popular item が hard negative に選ばれやすく,補正なしだと下げすぎになりうる。 Two-Tower では対照学習の負例を工夫する形で導入される。
ロス重み付け・正則化 popularity に応じた loss weight / embedding 正則化 人気アイテムに対する負例勾配の重みを下げる・ニッチアイテムを強めに正則化するなど,loss 側でバランスを取る。 サンプリングロジックを変えなくても試せる。 理論的には ad-hoc なことも多く,ハイパーパラメータ調整が必要。 TFRS では custom loss / custom task を書くことで実装可能。log p_j 補正に追加で fine-tuning 的に入れるケースも。
Pairwise / Ranking ロス BPR, pairwise logistic など + 負例サンプリング工夫 softmax ベースではなく,「正例が負例より上に来る」pairwise ランキングロスで学習する。負例サンプリング戦略と組み合わせて bias を抑える。 レコメンドの最終目的に近い形でロスを定義できる。softmax の温度などに悩まされにくい。 「どの負例をどうサンプルするか」の問題は残る。 TFRS の Ranking タスクや custom training loop で実装されることが多い。
Debiased Contrastive / Multi-task debiased CL, popularity から独立な表現を学習するマルチタスク 「人気だから観測された」バイアスをロス設計や補助タスクによって明示的に抑える。 理論・実装次第で,ログ自体の popularity bias をより広い意味で軽減できる可能性がある。 手法が多様で複雑になりがち。 Two-Tower をベースに拡張する形で研究・実務の両方で使われ始めている領域。

実務的には,まずは TFRS 標準の$ log p_j $補正(candidate_sampling_probability)を入れるのが一番コスパがよく,その上で必要に応じて

  • サンプリング分布の見直し
  • hard negative の導入
  • pairwise / ranking ロスや debiased CL への拡張

といった追加施策を検討していく,という流れになることが多いです。

6. 近似最近傍探索と HNSW

Two-Tower では推論時に,

  • クエリ $x$ を埋め込み $u(x)$ に変換
  • ベクトル空間上で「近い」アイテム $v(y)$ をたくさん取り出す

という処理を高速に行う必要があります。全アイテムを総当たりで内積計算するのは現実的ではないため,近似最近傍探索(Approximate Nearest Neighbor; ANN) を使います。 代表的な手法の一つが HNSW (Hierarchical Navigable Small World) です。

出典: Pinecone - Hierarchical Navigable Small Worlds (HNSW)

6.1 HNSW の直感的な動き

HNSW は「多層構造のグラフ」を使って近傍探索を行います。

  • 最上層(第2階層など)
    • ノード数が少ない粗いグラフ
    • ここからスタートする(図の緑の点など)
    • 目的地点に近づくようにグラフ上を貪欲に移動
    • 「今いるノードの近傍のうち,クエリに最も近いノード」に飛ぶ
    • これを繰り返していくと,だんだん目的地の近くへ
  • 一つ下の階層に降りる
    • 目的地にある程度近づいたら,下の階層に降りる
    • 下の階層ほどノードが多く,細かい探索になる
    • 最下層まで 2–3 を繰り返す
  • 最下層(第0階層)は全ての点を含む密なグラフ
    • 最終的にはここで近傍探索をして候補を返す

直感的には,

「高層階では大雑把に目的地の方向を決め,下の階層へ降りるほど細かく修正していく」という「階層付き山登り探索」をしているイメージです。 Two-Tower では,アイテム埋め込み $v(y)$ を HNSW のノードとしてインデックスしておき,推論時に $u(x)$ をクエリとして HNSW 検索を行うことで,

  • 数千万〜数億アイテム規模でも
  • 十数ミリ秒オーダーで上位 $K$ を取ってくる

ことが可能になります。

8. 実際に Two-Tower を動かしてみる

ここからは、実際に Two-Tower を動かしてin-batch negative sampling がどんなバイアスを生むのかそれをどう緩和できるのかを、MovieLens 100K を題材に覗いてみます。(チュートリアル参照

8.1 データと前処理

使ったデータは MovieLens 100K のレーティングログです。今回はレーティングが 3.5 以上の (user_id, movie_title) だけを取り出し、「そのユーザがその映画を好んだ」という 正例インタラクション とみなす という、すごく素朴な implicit-feedback 設定です。この条件で残る正例はだいたい 5.5 万行です。 そこから

  • 80% を学習用
  • 20% をテスト用

にランダム分割しています。

さらに、sampling bias を見るために「映画の人気度」をざっくり 3 つの帯に分けました。ある映画が何回ログに登場したかを単純にカウントし、

出現回数が少ない順から

  • 下位 70% → Tail
  • 上位 30%〜10% → Mid
  • 上位 10% → Head(超人気作)

というラベルを付けています。

テストデータの内訳はだいたい

  • Tail: 18%
  • Mid : 34%
  • Head: 48%

で、学習データ側もほぼ同じ比率です。 MovieLens 100K では「人気作がかなり多い」設定になっている、というのがまず一つ目の前提です。 この頻度から、各アイテムが「学習データの中でどれくらいの確率で登場するか」も計算してあります。 後で出てくる「log p_j 補正」はここから使っていますが、詳細は全部コード側に押し込み、本文では「人気度から計算している」とだけ覚えておけば十分です。

8.2 モデル構成とロスの違い

モデル自体はかなりシンプルです。

  • 入力は user_id と movie_title のみ
  • それぞれ Embedding → MLP(2 層)で 32 次元ベクトルに変換
  • 最後に L2 正規化して、ユーザベクトルと映画ベクトルの「内積」をスコアとして扱う

という、ごく標準的な Two-Tower です。ユーザ側/アイテム側でクラスを分けず、同じクラスの中に「ユーザ埋め込み用のネット」「アイテム埋め込み用のネット」を持たせています。

学習設定はざっくり

  • バッチサイズ 1024
  • エポック数 300
  • Optimizer は Adagrad(学習率 0.1)

です。1 エポックで学習用データをちょうど 1 回なめるので、 「300 回同じデータを見直しながら、少しずつパラメータを更新している」イメージです。

同じ Two-Tower に対して、ロスとサンプルの選び方 だけを変えた 4 パターンを比較しました。

  • Baseline
    いちばん素直な in-batch negative sampling
    1 バッチの中で、正解ペア以外はぜんぶ「負例」として扱うやり方
  • Bias-corrected
    Baseline と同じロジックですが、映画ごとに計算した「登場確率」から log p_j を求め、
    スコアに「− log p_j」のハンデを入れて学習します
    「人気だから頻繁に登場しているだけの映画」はペナルティを食らい、
    ユーザーの嗜好によるスコアだけで勝負しろ、という思想の手法です
  • Pop-aware sampling
    ロスの計算は Baseline と同じ
    ただし 学習に使うインタラクションの選び方 を変え、
    人気映画の行はわざと取りにくくし、ニッチ映画の行を多めにサンプルする
  • Loss weighting
    学習に使うインタラクションはそのまま
    代わりに、ロスの中で「人気映画の重みを小さく」「ニッチ映画の重みをやや大きく」する

ざっくり言うと、

  • Bias-corrected → 「スコアの計算式側で補正」
  • Pop-aware → 「どの行を学習に使うかを変える」
  • Loss weighting → 「ロスの重みでバランスをいじる」

という三つ巴で、すべて in-batch negative をベースにした小改造になっています。

8.3 評価の仕方

評価指標は全部「Recall@100」です。 テストの各 (ユーザ, 映画) について

  • そのユーザからクエリベクトルを計算
  • 全映画とのスコアを計算して、上位 100 本をリコメンド候補とみなす
  • その 100 本の中に「本当にそのユーザが見ていた映画」が入っているかどうかを見る

これを全テストデータで平均したものが Recall@100
というシンプルな指標です。

さらに、各モデルが返した Top-100 の中に、Tail / Mid / Head の映画がどれくらいの割合で含まれるかも集計しています。
これを見ると、「このモデルは人気帯のどこをどれくらい推しているのか」が直感的に分かります。

8.4 実験結果のざっくり読み解き

8.4.1 Recall vs epoch(raw スコア)

最初のグラフ(Recall vs epoch, raw score)を見ると、こんな傾向が出ています。

  • Bias-corrected が常に一番上にいて、
    Baseline に対してだいたい 1.5 倍くらい高い Recall を出している
  • Pop-aware / Loss weighting は Baseline とほぼ同じか、わずかに下

MovieLens のように人気映画が多いデータでは、「raw スコアでの精度だけ見れば、log p_j 補正が一番コスパがいい」という、結果になっています。

Recall vs epoch, raw score

8.4.2 埋め込みノルムと人気

  • Baseline では、人気が高くなるほど埋め込みノルムもゆるやかに増えていて、
    人気映画がどんどん「長いベクトル」になっていくのが分かります。
    これは in-batch negative で何度も負例として見せられ、
    他のアイテムから引き離そうとする勾配が蓄積した結果と考えられます。
  • Bias-corrected では、Tail 〜 Mid あたりでいったんノルムが落ち、
    Head 帯で少し持ち上がるような 逆 U 字 っぽい形になっています。
    「とりあえず人気順にどんどんノルムが伸びていく」というブラックホール状態からは脱却しているのが分かります。
Baseline :人気が高くなるほど埋め込みノルムもゆるやかに増える
log p_j 補正:“人気だからベクトルが大きくなるという傾向がかなり抑えられる

8.4.3 人気帯ごとの Recall@100

人気帯別の Recall@100(raw スコア版)の棒グラフからは、次のようなストーリーが読み取れます。

  • Baseline
    Tail / Mid / Head でそこそこ均等に当てているが、Head の精度はやや低め
  • Bias-corrected
    Head の Recall が一気に跳ね上がり、Tail / Mid は少し犠牲になる
    「とにかく人気作を外したくない」モデルになっている
  • Pop-aware / Loss weighting
    全体としては Baseline に近いが、やや Tail 寄りで穏やかな挙動

つまり

  • Baseline → 「人気作を過小評価しがち」
  • Bias-corrected → 「超人気作をかなり優遇」
  • Pop-aware / Loss weighting → その中間〜ちょい Tail 寄り

というイメージです。

人気帯ごとの Recall@100

8.4.4 Top-100 の人気分布(raw スコア)

  • データ本来の割合
    Tail 18%, Mid 34%, Head 48%
  • Baseline / Pop-aware / Loss-weighting の Top-100
    Tail 60% 超, Head 10% ちょっと
    → Tail を 3〜4 倍も出しすぎている
  • Bias-corrected の Top-100
    Tail ほぼ 0, Head 70% 超
    → Head 偏重モードに振り切れた

「人気の分布をどれだけ歪ませてしまっているか」という意味では、Baseline は ニッチを盛りすぎ、Bias-corrected は 人気を盛りすぎ という、両極端な絵になっています。

Top-100 の人気分布(raw スコア)

8.5 今回の実験から言えること

ここまでの結果を、雑にまとめるとこんな感じです。

  • in-batch negative sampling だけに頼ると学習時の負例として人気アイテムが出まくる。その結果Tail のアイテムを出しすぎてHead を外しがちなモデルになりやすい
  • log p_j 補正(Bias-corrected)を入れると人気アイテムの「負例としての出現しやすさ」を割り引ける。MovieLens では raw スコアでの Recall がかなり改善し,Top-100 も Head 寄りに戻ってくる。ただし、それでも「人気だから出ている部分」を完全には切り離せないので公平性やロングテール重視で評価したい場合はu・v − log p_j で評価し直す こともセットで考えた方が良い

「in-batch negative + そのままのスコア」で学習・評価していると 実はかなり popularity bias に振り切った世界を見ている可能性がある というのが今回の実験で確認できたポイントです。

その上で

  • クリック率や視聴時間のような「ビジネス指標」を最大化したいなら Bias-corrected を基本線にしつつ、
  • 「人気に寄りすぎず Tail もちゃんと拾いたい」なら debiased なスコアでの評価や sampling / loss の工夫もセットでやる

というのが、Two-Tower でネガティブサンプリングをいじるときの実務的な落としどころになりそうです。

Two-Towerモデルは、シンプルながらも「負例をどう扱うか」で挙動が大きく変わる奥深いモデルです。本記事が、皆さんの推薦システム構築の一助となれば幸いです🎄🎄

ここをクリックするとコードが開きます(実験コード,チュートリアルに足す形で動かしてください)

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

from collections import Counter




EMBED_DIM = 32 
HIDDEN_LAYERS = (64, 32) 
BATCH_SIZE = 1024   
EPOCHS = 300              
EVAL_EVERY = 5              # 5 epoch ごとに評価
TOP_K = 100                 # Recall@100 くらい
SEED = 42
N_MAX = None


tf.random.set_seed(SEED)
np.random.seed(SEED)


# 1. MovieLens100K をロード & 前処理(subset)

print("Loading MovieLens 100K...")

ds_raw = tfds.load("movielens/100k-ratings", split="train")

rows = []
for x in tfds.as_numpy(ds_raw):
    if x["user_rating"] >= 3.5:  # 正例とみなす閾値
        user = x["user_id"].decode("utf-8")
        movie = x["movie_title"].decode("utf-8")
        rows.append((user, movie))

print(f"Num positive interactions (full): {len(rows)}")

# 軽量化のためランダムに N_MAX 件に絞る
if (N_MAX is not None) and (len(rows) > N_MAX):
    perm = np.random.permutation(len(rows))[:N_MAX]
    rows = [rows[i] for i in perm]

print(f"Num positive interactions (used): {len(rows)}")

user_ids_np = np.array([u for u, _ in rows])
movie_titles_np = np.array([m for _, m in rows])
N = len(user_ids_np)

# 語彙
unique_user_ids = np.unique(user_ids_np)
unique_movie_titles = np.unique(movie_titles_np)

user_vocabulary = tf.keras.layers.StringLookup(
    vocabulary=unique_user_ids,
    mask_token=None,
)
movie_vocabulary = tf.keras.layers.StringLookup(
    vocabulary=unique_movie_titles,
    mask_token=None,
)

num_users = user_vocabulary.vocabulary_size()
num_items = movie_vocabulary.vocabulary_size()

print("num_users:", num_users, "num_items (with OOV):", num_items)

item_indices_np = movie_vocabulary(tf.constant(movie_titles_np)).numpy()


# 2. アイテム頻度 & p_j, popularity bin


print("Computing item frequencies and p_j...")

item_freq = np.zeros(num_items, dtype=np.float32)
for idx in item_indices_np:
    item_freq[idx] += 1.0

total_count = float(item_freq.sum())
p_j = item_freq / max(total_count, 1.0)
p_j = np.clip(p_j, 1e-8, 1.0)
log_p_j = tf.constant(np.log(p_j), dtype=tf.float32)

# popularity bin(Tail / Mid / Head): おおざっぱでOK
nonzero = item_freq[item_freq > 0]
q_mid = np.quantile(nonzero, 0.7)
q_head = np.quantile(nonzero, 0.9)

pop_bins = np.zeros_like(item_freq, dtype=np.int32)   # 0: tail
pop_bins[item_freq >= q_mid] = 1                      # 1: mid
pop_bins[item_freq >= q_head] = 2                     # 2: head
NUM_BINS = 3


# 3. Train / Test split

perm = np.random.permutation(N)
train_size = int(0.8 * N)
train_idx = perm[:train_size]
test_idx = perm[train_size:]

def make_dataset_from_indices(indices, batch_size=None, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices({
        "user_id": user_ids_np[indices],
        "movie_title": movie_titles_np[indices],
    })
    if shuffle and batch_size is not None:
        ds = ds.shuffle(10_000, reshuffle_each_iteration=True)
    if batch_size is not None:
        ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds

test_ds = make_dataset_from_indices(test_idx, batch_size=512, shuffle=False)


# 4. Two-Tower モデル


class TwoTower(tf.keras.Model):
    def __init__(self, embedding_dim=16, hidden_layers=(32,)):
        super().__init__()
        self.user_lookup = user_vocabulary
        self.item_lookup = movie_vocabulary

        self.user_embedding = tf.keras.layers.Embedding(num_users, embedding_dim)
        self.item_embedding = tf.keras.layers.Embedding(num_items, embedding_dim)

        self.user_mlp = tf.keras.Sequential(
            [tf.keras.layers.Dense(h, activation="relu") for h in hidden_layers]
        )
        self.item_mlp = tf.keras.Sequential(
            [tf.keras.layers.Dense(h, activation="relu") for h in hidden_layers]
        )

    def encode_user(self, user_ids):
        idx = self.user_lookup(user_ids)
        u = self.user_embedding(idx)
        u = self.user_mlp(u)
        u = tf.math.l2_normalize(u, axis=-1)
        return u

    def encode_item(self, movie_titles):
        idx = self.item_lookup(movie_titles)
        v = self.item_embedding(idx)
        v = self.item_mlp(v)
        v = tf.math.l2_normalize(v, axis=-1)
        return v, idx

    def encode_item_raw(self, movie_titles):
        idx = self.item_lookup(movie_titles)
        v = self.item_embedding(idx)
        v = self.item_mlp(v)
        return v, idx  # 正規化しない

    def call(self, inputs):
        u = self.encode_user(inputs["user_id"])
        v, v_idx = self.encode_item(inputs["movie_title"])
        return u, v, v_idx


# 5. In-batch Loss の4バリエーション

def inbatch_loss_baseline(u, v):
    logits = tf.matmul(u, v, transpose_b=True)  # [B, B]
    labels = tf.range(tf.shape(logits)[0])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits
    )
    return tf.reduce_mean(loss)

def inbatch_loss_bias_corrected(u, v, v_idx):
    logits = tf.matmul(u, v, transpose_b=True)
    batch_log_p = tf.gather(log_p_j, v_idx)          # [B]
    correction = tf.expand_dims(batch_log_p, axis=0) # [1, B]
    logits = logits - correction

    labels = tf.range(tf.shape(logits)[0])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits
    )
    return tf.reduce_mean(loss)

# Loss Weighting(おまけ枠)
# 正例アイテムの popularity に応じて per-example loss を重み付け
def inbatch_loss_loss_weighting(u, v, v_idx, beta=-0.25):
    logits = tf.matmul(u, v, transpose_b=True)
    labels = tf.range(tf.shape(logits)[0])
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits
    )
    freq = tf.gather(item_freq, v_idx)   # 正例アイテムの頻度
    weights = (freq + 1.0) ** beta       # popular ほど重み↓
    weights = weights / tf.reduce_mean(weights)
    return tf.reduce_mean(per_example * weights)

# popularity-aware sampling 用に、train index を再サンプリング
def make_popaware_indices(base_indices, alpha=-0.5):
    item_idx_for_base = item_indices_np[base_indices]
    freq = item_freq[item_idx_for_base]
    weights = freq ** alpha
    weights = weights / weights.sum()
    positions = np.random.choice(
        len(base_indices),
        size=len(base_indices),
        replace=True,
        p=weights,
    )
    return base_indices[positions]


# 6. Recall@K + popularity 別評価


def evaluate_model(model, k=TOP_K, use_bias_corrected=False):
    """
    use_bias_corrected=False: スコア = u·v       (raw モード)
    use_bias_corrected=True : スコア = u·v - log p_j (debiased モード)
    """
    vocab = tf.constant(movie_vocabulary.get_vocabulary())
    item_emb, _ = model.encode_item(vocab)          # [num_items, D]
    item_emb_T = tf.transpose(item_emb)             # [D, num_items]

    if use_bias_corrected:
        log_p = tf.reshape(log_p_j, (1, -1))        # [1, num_items]

    hits_total = 0
    count_total = 0
    hits_bin = np.zeros(NUM_BINS, dtype=np.int64)
    count_bin = np.zeros(NUM_BINS, dtype=np.int64)

    for batch in test_ds:
        u = model.encode_user(batch["user_id"])             # [B, D]
        true_idx = movie_vocabulary(batch["movie_title"])   # [B] int64

        scores = tf.matmul(u, item_emb_T)                   # [B, num_items]
        if use_bias_corrected:
            scores = scores - log_p                         # u·v - log p_j

        topk = tf.math.top_k(scores, k=k).indices           # [B, K] int32
        topk = tf.cast(topk, dtype=true_idx.dtype)          # int64 に揃える

        true_idx_exp = tf.expand_dims(true_idx, axis=1)
        hit = tf.reduce_any(tf.equal(true_idx_exp, topk), axis=1).numpy()
        true_idx_np = true_idx.numpy()
        bins_np = pop_bins[true_idx_np]

        hits_total += hit.sum()
        count_total += hit.size

        for b in range(NUM_BINS):
            mask = (bins_np == b)
            count_bin[b] += mask.sum()
            hits_bin[b] += (hit & mask).sum()

    overall = hits_total / max(count_total, 1)
    recall_by_bin = hits_bin / np.maximum(count_bin, 1)

    return overall, recall_by_bin






# 7. 学習ループ(1モデル分)
def train_one_model(name, loss_type, sampler="baseline"):
    print(f"\n=== Training {name} ===")

    model = TwoTower(embedding_dim=EMBED_DIM, hidden_layers=HIDDEN_LAYERS)
    optimizer = tf.keras.optimizers.Adagrad(0.1)

    history_epochs = []

    # raw / debiased それぞれの履歴
    history_recall_raw = []
    history_recall_deb = []
    history_recbins_raw = []
    history_recbins_deb = []

    for epoch in range(1, EPOCHS + 1):
        if sampler == "baseline":
            epoch_idx = train_idx
        elif sampler == "popaware":
            epoch_idx = make_popaware_indices(train_idx, alpha=-0.5)
        else:
            raise ValueError("Unknown sampler")

        train_ds = make_dataset_from_indices(epoch_idx,
                                             batch_size=BATCH_SIZE,
                                             shuffle=True)

        for batch in train_ds:
            with tf.GradientTape() as tape:
                u, v, v_idx = model(batch)
                if loss_type == "baseline":
                    loss = inbatch_loss_baseline(u, v)
                elif loss_type == "bias":
                    loss = inbatch_loss_bias_corrected(u, v, v_idx)
                elif loss_type == "loss_weight":
                    loss = inbatch_loss_loss_weighting(u, v, v_idx)
                else:
                    raise ValueError("Unknown loss_type")

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if epoch % EVAL_EVERY == 0:
            # (1) raw スコア
            rec_raw, recbins_raw = evaluate_model(
                model, k=TOP_K, use_bias_corrected=False
            )
            # (2) debiased スコア
            rec_deb, recbins_deb = evaluate_model(
                model, k=TOP_K, use_bias_corrected=True
            )

            history_epochs.append(epoch)
            history_recall_raw.append(rec_raw)
            history_recall_deb.append(rec_deb)
            history_recbins_raw.append(recbins_raw)
            history_recbins_deb.append(recbins_deb)

            print(
                f"[{name}] epoch {epoch:2d} "
                f"loss={loss.numpy():.4f}  "
                f"Recall_raw@{TOP_K}={rec_raw:.4f}  "
                f"Recall_deb@{TOP_K}={rec_deb:.4f}"
            )

    history_recall_raw = np.array(history_recall_raw)
    history_recall_deb = np.array(history_recall_deb)
    history_recbins_raw = np.stack(history_recbins_raw, axis=0)
    history_recbins_deb = np.stack(history_recbins_deb, axis=0)

    return (model, history_epochs,
            history_recall_raw, history_recall_deb,
            history_recbins_raw, history_recbins_deb)



# 8. 4モデルを一気に学習
results = {}

(model_base,
 ep_base,
 rec_base_raw,
 rec_base_deb,
 recbins_base_raw,
 recbins_base_deb) = train_one_model(
    name="baseline",
    loss_type="baseline",
    sampler="baseline",
)
results["baseline"] = (model_base, ep_base,
                       rec_base_raw, rec_base_deb,
                       recbins_base_raw, recbins_base_deb)

(model_bias,
 ep_bias,
 rec_bias_raw,
 rec_bias_deb,
 recbins_bias_raw,
 recbins_bias_deb) = train_one_model(
    name="bias_corrected",
    loss_type="bias",
    sampler="baseline",
)
results["bias"] = (model_bias, ep_bias,
                   rec_bias_raw, rec_bias_deb,
                   recbins_bias_raw, recbins_bias_deb)

(model_pop,
 ep_pop,
 rec_pop_raw,
 rec_pop_deb,
 recbins_pop_raw,
 recbins_pop_deb) = train_one_model(
    name="popaware_sampling",
    loss_type="baseline",
    sampler="popaware",
)
results["popaware"] = (model_pop, ep_pop,
                       rec_pop_raw, rec_pop_deb,
                       recbins_pop_raw, recbins_pop_deb)

(model_lw,
 ep_lw,
 rec_lw_raw,
 rec_lw_deb,
 recbins_lw_raw,
 recbins_lw_deb) = train_one_model(
    name="loss_weighting",
    loss_type="loss_weight",
    sampler="baseline",
)
results["loss_weight"] = (model_lw, ep_lw,
                          rec_lw_raw, rec_lw_deb,
                          recbins_lw_raw, recbins_lw_deb)



# 9. 図1: Recall vs epoch(raw / debiased)


# (1) raw スコア(u·v)で評価したとき
plt.figure(figsize=(6,4))
plt.plot(ep_base, rec_base_raw, label="baseline")
plt.plot(ep_bias, rec_bias_raw, label="bias-corrected")
plt.plot(ep_pop, rec_pop_raw, label="pop-aware sampling")
plt.plot(ep_lw,   rec_lw_raw,   label="loss weighting")
plt.xlabel("epoch")
plt.ylabel(f"Recall_raw@{TOP_K}")
plt.title("Recall vs epoch (raw score)")
plt.legend()
plt.grid(True)
plt.show()

# (2) debiased スコア(u·v - log p_j)で評価したとき
plt.figure(figsize=(6,4))
plt.plot(ep_base, rec_base_deb, label="baseline")
plt.plot(ep_bias, rec_bias_deb, label="bias-corrected")
plt.plot(ep_pop, rec_pop_deb, label="pop-aware sampling")
plt.plot(ep_lw,   rec_lw_deb,   label="loss weighting")
plt.xlabel("epoch")
plt.ylabel(f"Recall_debiased@{TOP_K}")
plt.title("Recall vs epoch (debiased score)")
plt.legend()
plt.grid(True)
plt.show()

########################################
# 10. 図2: popularity vs embedding norm(baseline vs bias)
########################################

def plot_pop_vs_norm_raw(model, title):
    vocab = tf.constant(movie_vocabulary.get_vocabulary())
    v_raw, _ = model.encode_item_raw(vocab)   # ★生ベクトル
    emb_np = v_raw.numpy()
    norms = np.linalg.norm(emb_np, axis=1)

    plt.figure(figsize=(6,4))
    plt.scatter(np.log1p(item_freq), norms, alpha=0.3, s=10)
    plt.xlabel("log(1 + item frequency)")
    plt.ylabel("raw embedding norm")
    plt.title(title)
    plt.grid(True)
    plt.show()

plot_pop_vs_norm_raw(model_base, "Baseline (raw): popularity vs norm")
plot_pop_vs_norm_raw(model_bias, "Bias-corrected (raw): popularity vs norm")

# 図: popularity bin 別 Recall(最終 epochの値, debiased スコア基準)
labels = ["tail", "mid", "head"]
x = np.arange(len(labels))
width = 0.2

final_base = recbins_base_deb[-1]
final_bias = recbins_bias_deb[-1]
final_pop  = recbins_pop_deb[-1]
final_lw   = recbins_lw_deb[-1]

plt.figure(figsize=(6,4))
plt.bar(x - 1.5*width, final_base, width, label="baseline")
plt.bar(x - 0.5*width, final_bias, width, label="bias-corrected")
plt.bar(x + 0.5*width, final_pop,  width, label="pop-aware sampling")
plt.bar(x + 1.5*width, final_lw,   width, label="loss weighting")
plt.xticks(x, labels)
plt.ylabel(f"Recall@{TOP_K} (debiased)")
plt.title("Recall@{} by popularity bin (final epoch, debiased score)".format(TOP_K))
plt.legend()
plt.grid(True, axis="y")
plt.show()

# 図: popularity bin 別 Recall(最終 epochの値, debiased スコア基準)
labels = ["tail", "mid", "head"]
x = np.arange(len(labels))
width = 0.2

final_base = recbins_base_deb[-1]
final_bias = recbins_bias_deb[-1]
final_pop  = recbins_pop_deb[-1]
final_lw   = recbins_lw_deb[-1]

plt.figure(figsize=(6,4))
plt.bar(x - 1.5*width, final_base, width, label="baseline")
plt.bar(x - 0.5*width, final_bias, width, label="bias-corrected")
plt.bar(x + 0.5*width, final_pop,  width, label="pop-aware sampling")
plt.bar(x + 1.5*width, final_lw,   width, label="loss weighting")
plt.xticks(x, labels)
plt.ylabel(f"Recall@{TOP_K} (debiased)")
plt.title(f"Recall@{TOP_K} by popularity bin (final epoch, debiased score)")
plt.legend()
plt.grid(True, axis="y")
plt.show()