← Python and ROCm 教室へ戻る

第12章

ログとshapeを読めるようになりたい

止まったときに、すぐ「どこで shape が合っていないか」を見つける力を付ける章です。

今日は何をしたい?

ここまでいくつかの部品を動かしてきました。実際にはエラーで止まることもよくあります。エラーが出たときに、ただ怖がるのではなく、shape と期待していた大きさを見比べて原因を追えるようにします。ROCm の前に、まず shape のずれで止まることはとても多いです。

エラーに出会ったら、この 3 手順で見ます

  1. エラーの最後の行を読む — 何が合わなかったかが書いてある
  2. shape / device / dtype のどれが原因かを切り分ける
  3. 原因の数字を print して、期待値と比べる
  • shape — 入力の次元数や大きさが、層の期待と合っているか?
  • device — テンソルとモデルが同じ場所(CPU / GPU)にいるか?
  • dtype — Float と Long など、型が混ざっていないか?

最小コード — 壊れた例

import torch
import torch.nn as nn

x = torch.randn(2, 3)
layer = nn.Linear(4, 2)

print("x shape:", x.shape)
print("expected in_features:", layer.in_features)

try:
    y = layer(x)
except RuntimeError as e:
    print(type(e).__name__)
    print(e)

入力の最後の大きさ 3 と、Linear(4, 2) が期待する 4 が合っていません。

エラー出力はこうなります。注目すべき場所にマークを付けました。

x shape: torch.Size([2, 3]) expected in_features: 4 RuntimeError mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)

最後の行の 2x34x2 が、shape のずれそのものです。入力の列数 (3) と重みの行数 (4) が一致しないと掛け算できません。

直した例

import torch
import torch.nn as nn

x = torch.randn(2, 3)
layer = nn.Linear(3, 2)  # 4 → 3 に修正

print("x shape:", x.shape)
print("expected in_features:", layer.in_features)

y = layer(x)
print("y shape:", y.shape)
x shape: torch.Size([2, 3]) expected in_features: 3 y shape: torch.Size([2, 2])
エラーが出たら 最後の行を読む → shape / device / dtype のどれかを確認 — この 2 ステップだけで見通しがつきます。

Linear(3, 2) にしたので、入力の最後の 3 と一致して通るようになりました。

実行コマンド

ファイルを chapter12.py という名前で保存したら、ターミナルで次のコマンドを実行します。

python chapter12.py

どこがROCm?

GPU で止まったように見えても、原因が shape mismatch ということはよくあります。ROCm の調子を疑う前に、まず入力 shape、期待 shape、device を見る習慣が役立ちます。

特に device の問題は、CPU テンソルと GPU モデルを混ぜると起きます。エラーに Expected all tensors to be on the same device と出たら、device チェックが先です。

ここで出てきたPython

エラーメッセージの読み方のコツ(クリックで開く)
  • エラーメッセージの型と本文を分けて表示すると、原因が見えやすくなります。
  • 最後の行にある数字を読むのがコツです。
  • shape が合っているのに通らないときは dtype を疑います(「Float なのに Long が来た」のような形)。

よくあるつまずき

1分演習

x = torch.randn(2, 5) に変えて(列数を 3 から 5 に変更)、エラーを出してみましょう。エラーメッセージの最後の行を読んで、どの数字がずれているかを読んでみます。次に nn.Linear(5, 2) に直して、止まらず進むかを確かめます。