← Python and ROCm 教室へ戻る

第11章

Attentionをのぞきたい

難しく見えやすい Attention も、まずは shape と流れから入ると追いやすくなります。

今日は何をしたい?

第10章は画像向きの部品(Conv2d)でしたが、この章は文章や系列データ向きの部品です。Attention を全部理解するより先に、「入力が入って、重み付きの出力が返る」流れを形で見ます。ここでは QKV の理屈を詰め込みすぎず、まずは 1 回動かしてみます。

データの流れはこうです。

x (入力) MultiheadAttention y (出力) + weights (注目度)
token
入力の「1 個ぶん」。文なら単語、系列なら 1 ステップ。
embed_dim
各 token を表す数字の個数。ここでは 4。
weights
「どの token をどれくらい見たか」を表すスコア。

確認コード

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

attn = nn.MultiheadAttention(
    embed_dim=4,
    num_heads=1,
    batch_first=True,  # shape を (バッチ, token数, embed_dim) の順にする設定
).to(device)

x = torch.randn(1, 3, 4, device=device)
y, weights = attn(x, x, x)

print("x      shape:", x.shape)
print("y      shape:", y.shape)
print("weights shape:", weights.shape)

batch_first=True を付けると shape が(バッチ, token数, embed_dim)の順になります。付けないとバッチが2番目に来るので、ここでは付けておきます。

x shape: torch.Size([1, 3, 4]) y shape: torch.Size([1, 3, 4]) weights shape: torch.Size([1, 3, 3])

3 つの shape を並べると、この対応が見えます。

変数shape読み方
x (1, 3, 4) バッチ 1、token 3 個、各 4 次元
y (1, 3, 4) 入力と同じ形で返る
weights (1, 3, 3) token 3 × token 3 の注目スコア

weights が (3, 3) なのは、3 個の token が「お互いをどれくらい見ているか」を表しているからです。3×3 の各行が「その token が他の token をどれだけ注目しているか」のスコアに対応します。

x.shapey.shapeweights.shape の 3 つが読めればこの章は OK です。数式を全部追う必要はありません。

実行コマンド

ファイルを chapter11.py という名前で保存したら、ターミナルで次のコマンドを実行します。

python chapter11.py

Q / K / V ってなに?

attn(x, x, x) の 3 つの引数は、順に Query(何を探したいか)、Key(何を手がかりにするか)、Value(何を返すか)です。

ここでは全部同じ x を入れています。これを Self-Attention と呼び、「自分の中で、どの部分に注目するか」を計算しています。最初は「同じものを 3 回渡して注目度を計算する」と読めれば十分です。

どこがROCm?

Attention の中では、行列計算や softmax のような処理がいくつも重なります。これらは GPU 上で並列に走るため、ROCm の得意な場面です。「複数の重い計算が GPU で同時に動いている」というイメージで読むと整理しやすいです。

ここで出てきたPython

よくあるつまずき

1分演習

x = torch.randn(1, 4, 4, ...) に変えて、token の数を 3 から 4 にしたとき、どの shape が変わるかを見てみましょう。