← Python and ROCm 教室へ戻る

第3章

表どうしを計算したい

行列の掛け算は、AI や GPU の話に何度も出てくる大事な形です。

今日は何をしたい?

NumPy の表どうしを掛けて、入力の shape と出力の shape がどう変わるかを読めるようにします。まずは「数字そのもの」より「表の形」に注目します。

なぜ行列の掛け算が大事なの?(クリックで開く)

AI の学習も推論も、中心にあるのは「大量の数字を行列として掛け算する」操作です。画像の特徴抽出、文章の意味変換、重みの更新——どれも内側では行列積(matmul)が動いています。GPU が AI に使われる理由のひとつは、この行列積を並列で高速に計算できるからです。

確認コード

コード中に @ という記号が出てきます。これは Python で「行列の掛け算(行列積)」を表す記号です。詳しい意味は後の「ここで出てきたPython」で説明します。

import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]], dtype=np.float32)
b = np.array([[7, 8],
              [9, 10],
              [11, 12]], dtype=np.float32)

c = a @ b

print("a.shape:", a.shape)
print("b.shape:", b.shape)
print("c.shape:", c.shape)
print(c)

このコードを chapter03.py として保存して、python chapter03.py と実行します。

a(2, 3)b(3, 2) です。結果の c(2, 2) になります。

shape のルール: (m, n) @ (n, p) → (m, p) — 内側の数字が一致していれば掛け算できます。

a (2, 3) × b (3, 2) = c (2, 2)
内側の数字(a の列数と b の行数)が一致しているので掛け算できます。結果の shape は外側の2つの数字になります。
なぜ内側が揃わないといけないの?(クリックで開く)

行列積では a の1行(3個の数字)と b の1列(3個の数字)を「順番に掛けて全部足す」操作を繰り返します。1行が 3個 なら相手の1列も 3個 でなければ掛け合わせられません。これが「内側の一致」が必要な理由です。

実行するとこう出ます

a.shape: (2, 3) b.shape: (3, 2) c.shape: (2, 2) [[ 58. 64.] [139. 154.]]

最重要は c.shape: (2, 2) です。数字の中身より先に shape を確認しましょう。

数字の検算が気になる方へ(クリックで開く)

58. という値は、a の 1 行目 [1, 2, 3] と b の 1 列目 [7, 9, 11] を掛けて足した結果(1×7 + 2×9 + 3×11 = 58)です。数字の検算はしなくていいですが、仕組みが気になったときの参考にしてください。

どこがROCm?

この「表どうしを掛ける」計算は、PyTorch では torch.matmul() や同じ @ 記号で書けます。そしてその裏では rocBLAS(AMD GPU 向けの行列計算エンジン)が動いています(名前は今は覚えなくて大丈夫です)。

つまり a @ b という1行が、GPU 上では rocBLAS の高速な行列積カーネルに変換されて実行されます。ROCm を使う意味のひとつは、この変換経路が整備されていることです。

この章は CPU 上の NumPy ですが、次の章で PyTorch テンソルに移ると、同じ @ の書き方のまま GPU 計算に接続できます。

ここで出てきたPython

* と @ の違い早見表(クリックで開く)
a * b(要素ごと) a @ b(行列積)
計算内容 同じ位置の要素を掛ける 行と列を掛けて足す
shape の条件 同じ shape が必要 内側が一致すればOK
結果の shape 入力と同じ 外側が残る

うまくいかなかったら

結果パターン別の次の一手(クリックで開く)
  • パターンA: c.shape: (2, 2) と出る → この章はクリア。次章へ進んでOK。
  • パターンB: ValueError: matmul: Input operand 1 has a mismatch... のようなエラーが出る → 内側の次元が不一致です。a.shapeb.shape を print して「a の列数 = b の行数」になっているか確認。
  • パターンC: ModuleNotFoundError: numpy → numpy未導入。pip install numpy 後に再実行。

よくあるつまずき

1分演習

次の順で試しましょう。1) 先に c.shape を予想する。2) 実行して答え合わせ。3) 失敗したら内側次元を見直す。

import numpy as np

a = np.array([[1, 2],
              [3, 4]], dtype=np.float32)   # (2, 2)
b = np.array([[5],
              [6]], dtype=np.float32)      # (2, 1)
c = a @ b
print("c.shape:", c.shape)

予想できたら実行して確かめましょう。
— 答えは実行結果で確認してください。ここが迷わず読めれば、第3章の目標達成です。