第11章
Attentionをのぞきたい
難しく見えやすい Attention も、まずは shape と流れから入ると追いやすくなります。
今日は何をしたい?
第10章は画像向きの部品(Conv2d)でしたが、この章は文章や系列データ向きの部品です。Attention を全部理解するより先に、「入力が入って、重み付きの出力が返る」流れを形で見ます。ここでは Q、K、V の理屈を詰め込みすぎず、まずは 1 回動かしてみます。
データの流れはこうです。
- token
- 入力の「1 個ぶん」。文なら単語、系列なら 1 ステップ。
- embed_dim
- 各 token を表す数字の個数。ここでは 4。
- weights
- 「どの token をどれくらい見たか」を表すスコア。
確認コード
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
attn = nn.MultiheadAttention(
embed_dim=4,
num_heads=1,
batch_first=True, # shape を (バッチ, token数, embed_dim) の順にする設定
).to(device)
x = torch.randn(1, 3, 4, device=device)
y, weights = attn(x, x, x)
print("x shape:", x.shape)
print("y shape:", y.shape)
print("weights shape:", weights.shape)batch_first=True を付けると shape が(バッチ, token数, embed_dim)の順になります。付けないとバッチが2番目に来るので、ここでは付けておきます。
3 つの shape を並べると、この対応が見えます。
| 変数 | shape | 読み方 |
|---|---|---|
| x | (1, 3, 4) | バッチ 1、token 3 個、各 4 次元 |
| y | (1, 3, 4) | 入力と同じ形で返る |
| weights | (1, 3, 3) | token 3 × token 3 の注目スコア |
weights が (3, 3) なのは、3 個の token が「お互いをどれくらい見ているか」を表しているからです。3×3 の各行が「その token が他の token をどれだけ注目しているか」のスコアに対応します。
実行コマンド
ファイルを chapter11.py という名前で保存したら、ターミナルで次のコマンドを実行します。
python chapter11.pyQ / K / V ってなに?
attn(x, x, x) の 3 つの引数は、順に Query(何を探したいか)、Key(何を手がかりにするか)、Value(何を返すか)です。
ここでは全部同じ x を入れています。これを Self-Attention と呼び、「自分の中で、どの部分に注目するか」を計算しています。最初は「同じものを 3 回渡して注目度を計算する」と読めれば十分です。
どこがROCm?
Attention の中では、行列計算や softmax のような処理がいくつも重なります。これらは GPU 上で並列に走るため、ROCm の得意な場面です。「複数の重い計算が GPU で同時に動いている」というイメージで読むと整理しやすいです。
ここで出てきたPython
- y, weights = ... — 2 つの戻り値を受け取る
- キーワード引数で設定を明示
よくあるつまずき
- embed_dim=4 なら、入力の最後の大きさも 4 に合わせます。ここがずれると shape エラーになります。
- token は「文の中の小さな単位」くらいの感覚で大丈夫です。
- 最初から数式を全部追うより、x.shape、y.shape、weights.shape の 3 つを見るほうが入りやすいです。
- Q / K / V が分からなくても、「同じ x を 3 回渡すと Self-Attention になる」と覚えれば先に進めます。
1分演習
x = torch.randn(1, 4, 4, ...) に変えて、token の数を 3 から 4 にしたとき、どの shape が変わるかを見てみましょう。