【PyTorch】サンプル⑥ 〜 nn パッケージ 〜


1. 目的
2. 前準備
3. nn パッケージ
4. PyTorchのインポート
5. 使用するデータ
6. ニューラルネットワークのモデルを定義
7. 損失(loss)の定義
8. 学習パラメータ
9. モデルへのデータ入出力 (Forward pass)
10. 損失(loss)の計算
11. 勾配の初期化
12. 勾配の計算
13. パラメータ(Weight)の更新
14. 実行
14.1. 6_nn_package.py
15. 終わりに


1. 目的

  • PyTorch: nnを参考にPyTorchのnnパッケージを扱う。
  • nnパッケージの便利さを感じる。

2. 前準備

PyTorchのインストールはこちらから。

初めて、Google Colaboratoryを使いたい方は、こちらをご覧ください。

3. nn パッケージ

nnは、ニューラルネットワークの構築に用いる。

PyTorchの自動微分autogradによって、計算グラフやパラメータの勾配を簡単に計算することができます。ですが、自動微分だけで複雑なニューラルネットワークを定義するのは困難です。そこで活躍するのがnnパッケージです。

nnパッケージを用いて、ニューラルネットワークを一つのモジュールとして定義することができます。

4. PyTorchのインポート

import torch

5. 使用するデータ

バッチサイズNを64、入力の次元D_inを1000、隠れ層の次元Hを100、出力の次元D_outを10とします。

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

入力(x)と予測したい(y)を乱数で定義します。

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

6. ニューラルネットワークのモデルを定義

ニューラルネットワークのモデルをnnパッケージを用いて定義します。

定義の仕方は、大きく2つありますがここでは一番簡単なtorch.nn.Sequentialを使います。

作り方は簡単で、任意の層を積み重ねていくだけです。この例では、input > Linear(線型結合) > ReLU(活性化関数) > Linear(線型結合) > outputの順に層が積み重なっています。

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

7. 損失(loss)の定義

二乗誤差もnnパッケージを用いて計算することができます。

reductionのデフォルトはmeanですので、何も指定しなければtorch.nn.MSELossは平均二乗誤差を返します。

reduction=sumとした場合は、累積二乗誤差を算出します。

loss_fn = torch.nn.MSELoss(reduction='sum')

(参考) nnパッケージを使う前は以下のように記述していた。

    loss = (y_pred - y).pow(2).sum()

8. 学習パラメータ

学習率を1e-4として、学習回数を500回とします。

learning_rate = 1e-4
for t in range(500):

9. モデルへのデータ入出力 (Forward pass)

定義したニューラルネットワークモデルへデータxを入力し、予測値y_predを取得します。

y_pred = model(x)

(参考) nnパッケージを使う前は以下のように記述していた。

    y_pred = x.mm(w1).clamp(min=0).mm(w2)

10. 損失(loss)の計算

定義した損失関数で予測値y_predと真値yとの間の損失を計算します。

    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

11. 勾配の初期化

逆伝播(backward)させる前に、モデルのパラメータが持つ勾配を0(ゼロ)で初期化します。

    model.zero_grad()

12. 勾配の計算

backwardメソッドでモデルパラメータ(Weight)の勾配を算出します。

    loss.backward()

13. パラメータ(Weight)の更新

確率勾配降下法(SGD: stochastic gradient descent)で、Weightを更新する。

    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate param.grad

(参考) nnパッケージを使う前は以下のように記述していた。

    with torch.no_grad():
        w1 -= learning_rate w1.grad
        w2 -= learning_rate w2.grad

14. 実行

以下のコードを6_nn_package.pyとして保存します。

14.1. 6_nn_package.py

import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. 
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. 
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    loss.backward()

    # Update the weights using gradient descent.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate param.grad

保存ができたら実行しましょう。

左の数字が学習回数、右の数値がパーセプトロンの推定値と実際の答えと二乗誤差です。

学習を重ねるごとに、二乗誤差が小さくなることがわかります。

$ python3 6_nn_package.py 
99 2.5003600120544434
199 0.06977272033691406
299 0.003307548351585865
399 0.00018405442824587226
499 1.1299152902211063e-05

15. 終わりに

多層のニューラルネットワークには、膨大な量のパラメータが存在しています。

nnパッケージを用いる前は、各パラメータごとに勾配計算やパラメータの更新などをしていましたが、それでは記述が困難です。

nnパッケージを用いることで、楽に勾配計算やパラメータの更新等が実行できることが感じてもらえたら嬉しいです(^^)!。

Print Friendly, PDF & Email

コメントを残す

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください