【PyTorch】チュートリアル(日本語版 )④ 〜TRAINING A CLASSIFIER(画像分類)〜

目的

PyTorchのチュートリアルTraining a Classifierを参考にPyTorchで画像分類について学ぶ。

具体的には、

  • ニューラルネットワークの構築
  • lossの計算
  • ネットワークの重みの更新
    について学習する。

PyTorchで扱うデータについて

一般的に、画像、テキスト、音声、動画を扱う際には、Pythonの標準ライブラリを使ってデータをNumPy配列として取り込むことができます。
そして、そのNumPy配列をtorch.*Tensorでテンソル配列に変換することができます。

画像の場合は、PillowやOpenCVのようなパッケージがよく使われます。
音声の場合は、scipyやlibrosaが便利です。
テキストの場合は、Pythonの標準ライブラリやNLTK、SpaCyを使います。

特に、PyTorchで画像の場合は、torchvisionパッケージを使います。
torchvisionには、Imagenet、CIFAR10、MNISTといった有名なデータセットがあります。

このチュートリアルでは、CIFAR10を使います。
CIFAR10データセットには、飛行機、自動車、鳥、猫、鹿、犬、蛙、馬、船、トラックといった10のクラスがあります。
画像はカラー(3-channel) で32×32 pixelsです。

(出典先: CIFAR-10 and CIFAR-100 datasets)

画像分類器の作り方

画像分類器は、以下の手順で創ることができます。

  1. torchvisionを使ってCIFAR10の読み込みと標準化
  2. 畳み込みニューラルネットワーク(CNN)の構築
  3. loss関数(損失関数)の定義
  4. Trainデータを使ってCNNに学習
  5. 学習済みのCNNを使ってTestデータを使って画像を分類

1. データの読み込みと標準化

必要なパッケージをimportします。

import torch
import torchvision
import torchvision.transforms as transforms

torchvisionのデータセットはrange[0,1]のPILImageなので、range[-1,1]に標準化したテンソルに変換する。
実際には、

  • transforms.Composeで読み込んだデータの前処理関数の構成
  • transforms.ToTensor()でテンソルに変換
  • transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))で標準化
    します。
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))の、1つめの引数(タプル)がRGBそれぞれの平均、2つめの引数(タプル)が標準偏差であり、これらの平均と標準偏差をつかって標準化します。
Y = \dfrac{X-\mu}{\sigma}
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

TrainデータとTestデータをそれぞれ読み込んでいきます。
データの読み込みはtorchvision.datasets.CIFAR10を使います。
この際、download=Trueであればrootに画像を保存します。
この場合は、./data保存されることになります。

CIFAR10データセットは6万枚の画像の内、5万はTrainデータ、1万はTestデータというようにわけられています。そのため、trainsetではtrain=True
testsetではtrain=FalseとすることでTrainデータとTestデータを簡単に分けることができます。

既にテンソルに変換・標準化されたデータ(transform)を読み込みたいので、transform=transformとします。

torch.utils.data.DataLoaderでは、バッチサイズbatch_sizeやデータシャッフルの有無shuffle、どれだけ並列処理するかnum_workersを設定します。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

飛行機、自動車、鳥、猫、鹿、犬、蛙、馬、船、トラックをタプルとしてclassesに格納します。

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

それでは、Trainデータを見てみましょう。
まずは、画像の可視化に必要なパッケージのimportです。

import matplotlib.pyplot as plt
import numpy as np

標準化した画像をもとに戻し、テンソル配列からNumPy配列にします。
また、画像の各次元が(チャネル数(RGB), 高さ, 横幅)と並んでいますが、plt.showは (高さ, 横幅, チャネル数)で並んでいる必要があるのでnp.transposeで並び替えます。

def imshow(img):
    img = img / 2 + 0.5   # 標準化を戻す
    npimg = img.numpy()   # NumPy配列に変換
    plt.imshow(np.transpose(npimg, (1, 2, 0)))   # (高さ, 横幅, チャネル数)となるよう整形
    plt.show()   #画像の表示

iterを使うことで各バッチごとにデータを読み出すことができます。
dataiter.next()を呼ぶことで次々とバッチごとに画像とそのラベルを返してくれます。

# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

先程自作で定義したimshowを使って画像を可視化します。
引数のtorchvision.utils.make_gridによって複数の画像を横並びにしてくれます。
各画像とラベルを確認してみます。

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

cat car dog dog

以下の画像が、表示される。

2. 畳み込みニューラルネットワーク(CNN)の構築

CNNを構築していきます。
ニューラルネットワークの定義については、こちらをご覧ください。
こちらと同じニューラルネットワークを使用します。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

3. loss関数と最適化

このチュートリアルでは、損失関数として交差エントロピー(Cross-Entropy)を、最適化関数として、Momentum SDGを使います。

損失関数の設定は、nn.CrossEntropyLoss()でしています。
また、最適化関数の設定は、optim.SGD(net.parameters(), lr=0.001, momentum=0.9)でしています。
net.parameters()は先程定義したCNNのパラメータ(重み)で、これがoptim.SGDで更新されることになります。
lr=0.001は学習率が0.001であることを示しています。
momentumでは、SGDのmomentumを設定することができます。デフォルトはmomentum=0です。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()   # 損失関数を交差エントロピーに設定
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)   # 最適化関数をSGDにしてmomentumを指定

4. CNNの学習

次に、構築したCNNの学習をします。
CNNは学習の中で、ネットワーク内のパラメータを更新し、正しい答えが導き出せるようにパラメータを最適化します。
これが、CNNの学習でやっていることです。

今回は、簡単のため、学習回数epochを2回にします。

optimizerは計算した勾配(grad)を記録し蓄積し続けるので、一度optimizer.zero_grad()で勾配初期化します。

outputs = net(inputs)でTrainデータをCNNに流し込み、CNNからの出力を取得します。

loss = criterion(outputs, labels)では、CNNからの出力outputsと実際の答えlabelsとの間の交差エントロピーを計算します。

loss.backward()では、目的関数である交差エントロピーに含まれるパラメータの微分係数を計算します。

optimizer.step()では、loss.backward()で計算した勾配をもとに、CNNのパラメータを更新します。

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

[1, 2000] loss: 2.219
[1, 4000] loss: 1.815
[1, 6000] loss: 1.649
[1, 8000] loss: 1.561
[1, 10000] loss: 1.504
[1, 12000] loss: 1.456
[2, 2000] loss: 1.385
[2, 4000] loss: 1.367
[2, 6000] loss: 1.339
[2, 8000] loss: 1.341
[2, 10000] loss: 1.295
[2, 12000] loss: 1.280
Finished Training

学習したモデル(CNN)をcifar_net.pthという名前で保存します。
モデルのパラメータを取り出すには、net.state_dict()を使います。
モデルの保存には、torch.saveでできます。

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

5. 学習済みCNNのテスト

CNNの学習回数を2にして学習させました。次に、このCNNの性能をTestデータを使って評価したいと思います。

まず、Testデータを表示します。

dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

GroundTruth: cat ship ship plane

以下の画像が、表示される。

次に、練習のために保存した学習済みのモデル(cifar_net.pth)を読み込みます。
モデルの読み込みには、net.load_state_dictを使います。

net = Net()
net.load_state_dict(torch.load(PATH))

読み込んだモデルに画像データを入力します。

outputs = net(images)

CNNは、10クラス分の出力を出します。
これら10個の出力の中で最も大きな値をCNNが推定したクラスだと考えます。
outputsの1つめの出力を見てみると、最大値は2.8439で、その値は10個の要素の内0から数えて3番目であることがわかります。
最初の方で定義したclassesは、classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')なので、CNNが推定する結果はcatになります。

print("outputs:\n{}".format(outputs))

_, predicted = torch.max(outputs, 1)
print("max value:{}".format(_))
print("Predicted:{}".format(predicted))

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))

outputs:
tensor([[-1.2680, -1.6149, 0.9799, 2.8439, 0.3981, 2.5890, -0.2455, 0.5544,
-1.3237, -1.6653],
[ 3.6198, 5.9820, -2.1858, -3.1164, -4.8510, -4.3395, -3.1334, -4.9847,
6.9285, 4.8581],
[ 0.9089, 2.3129, -0.1770, -0.7791, -1.8234, -1.6938, -1.0573, -1.6433,
1.7940, 1.7207],
[ 3.2211, -0.6499, 0.9810, -1.7527, -0.6010, -2.6471, -1.3576, -1.9702,
4.0528, 0.5546]], grad_fn=)

max value:tensor([2.8439, 6.9285, 2.3129, 4.0528], grad_fn=)
Predicted:tensor([3, 8, 1, 8])

Predicted: cat plane ship ship

今は、1バッチ分のTestデータしか評価しませんでしが、すべてのTestデータをまとめて評価したい場合には、このように記述します。

後ほど、正確さ(Accuracy)を計算できるように、Testデータ数totalと正解したデータ数correcttotal += labels.size(0)correct += (predicted == labels).sum().item()の部分でカウントします。

with torch.no_grad():でCNNによる画像分類の前に勾配を初期化しています。

for data in testloader:で、testloaderに格納されているバッチデータを1づつ取り出してdataに格納します。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 50 %

先程は、10クラス全体でのAccuracyを計算しましたが、次にクラスごとのAccuracyを計算します。

クラスごとのデータ数class_totalと正解したデータ数class_correctclass_correct[label] += c[i].item()class_total[label] += 1の部分で計算します。

c = (predicted == labels).squeeze()では、CNNの推定クラスと実際のクラスが同じであるかをbool型のテンソル返し、さらにテンソルの次元の中で,1のものを消します.

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane : 43 %
Accuracy of car : 55 %
Accuracy of bird : 51 %
Accuracy of cat : 42 %
Accuracy of deer : 37 %
Accuracy of dog : 51 %
Accuracy of frog : 58 %
Accuracy of horse : 66 %
Accuracy of ship : 65 %
Accuracy of truck : 76 %

GPUを使った学習

先程のニューラルネットワークの学習では、CPUを用いて学習しましたが、GPUを用いてニューラルネットワークを学習させることができます。

まずは、CUDAがお使いのPCで認識されているか確認してみましょう。
CUDAが認識されている場合は、cuda:と表示され、認識されていない場合は、cpuと表示されます。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0

ニューラルネットワークnetのパラメータをCUDAテンソルの変換します。

net.to.(device)

ネットワークパラメータだけではなく、入力するデータもCUDAテンソルの変換する必要があります。
したがって、各バッチごとにdata.to(device)を使ってCUDAテンソルに変換します。

inputs, labels = data[0].to(device), data[1].to(device)

コメントを残す

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください