{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "basic_classification.ipynb", "version": "0.3.2", "provenance": [], "private_outputs": true, "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "cells": [ { "metadata": { "colab_type": "text", "id": "MhoQ0WE77laV" }, "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "metadata": { "cellView": "form", "colab_type": "code", "id": "_ckMIh7O7s6D", "colab": {} }, "cell_type": "code", "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ], "execution_count": 0, "outputs": [] }, { "metadata": { "cellView": "form", "colab_type": "code", "id": "vasWnqRgy1H4", "colab": {} }, "cell_type": "code", "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "jYysdyb-CaWM" }, "cell_type": "markdown", "source": [ "# はじめてのニューラルネットワーク:分類問題の初歩" ] }, { "metadata": { "colab_type": "text", "id": "S5Uhzt6vVIB2" }, "cell_type": "markdown", "source": [ "\n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", "
" ] }, { "metadata": { "id": "RSywPQ2n736s", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。" ] }, { "metadata": { "colab_type": "text", "id": "FbVhjPpzn6BM" }, "cell_type": "markdown", "source": [ "このガイドでは、スニーカーやシャツなど、身に着けるものの写真を分類するニューラルネットワークのモデルを訓練します。すべての詳細を理解できなくても問題ありません。TensorFlowの全体を早足で掴むためのもので、詳細についてはあとから見ていくことになります。\n", "\n", "このガイドでは、TensorFlowのモデルを構築し訓練するためのハイレベルのAPIである [tf.keras](https://www.tensorflow.org/guide/keras)を使用します。" ] }, { "metadata": { "colab_type": "code", "id": "dzLKpmZICaWN", "colab": {} }, "cell_type": "code", "source": [ "from __future__ import absolute_import, division, print_function, unicode_literals\n", "\n", "# TensorFlow と tf.keras のインポート\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "# ヘルパーライブラリのインポート\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "print(tf.__version__)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "yR0EdgrLCaWR" }, "cell_type": "markdown", "source": [ "## ファッションMNISTデータセットのロード" ] }, { "metadata": { "colab_type": "text", "id": "DLdCchMdCaWQ" }, "cell_type": "markdown", "source": [ "このガイドでは、[Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist)を使用します。Fashion MNISTには10カテゴリーの白黒画像70,000枚が含まれています。それぞれは下図のような1枚に付き1種類の衣料品が写っている低解像度(28×28ピクセル)の画像です。\n", "\n", "\n", " \n", " \n", "
\n", " \"Fashion\n", "
\n", " Figure 1. Fashion-MNIST samples (by Zalando, MIT License).
 \n", "
\n", "\n", "Fashion MNISTは、画像処理のための機械学習での\"Hello, World\"としてしばしば登場する[MNIST](http://yann.lecun.com/exdb/mnist/) データセットの代替として開発されたものです。MNISTデータセットは手書きの数字(0, 1, 2 など)から構成されており、そのフォーマットはこれから使うFashion MNISTと全く同じです。\n", "\n", "Fashion MNISTを使うのは、目先を変える意味もありますが、普通のMNISTよりも少しだけ手応えがあるからでもあります。どちらのデータセットも比較的小さく、アルゴリズムが期待したとおりに機能するかどうかを確かめるために使われます。プログラムのテストやデバッグのためには、よい出発点になります。\n", "\n", "ここでは、60,000枚の画像を訓練に、10,000枚の画像を、ネットワークが学習した画像分類の正確性を評価するのに使います。TensorFlowを使うと、下記のようにFashion MNISTのデータを簡単にインポートし、ロードすることが出来ます。" ] }, { "metadata": { "colab_type": "code", "id": "7MqDQO0KCaWS", "colab": {} }, "cell_type": "code", "source": [ "fashion_mnist = keras.datasets.fashion_mnist\n", "\n", "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "t9FDsUlxCaWW" }, "cell_type": "markdown", "source": [ "ロードしたデータセットは、NumPy配列になります。\n", "\n", "* `train_images` と `train_labels` の2つの配列は、モデルの訓練に使用される**訓練用データセット**です。\n", "* 訓練されたモデルは、 `test_images` と `test_labels` 配列からなる**テスト用データセット**を使ってテストします。\n", "\n", "画像は28×28のNumPy配列から構成されています。それぞれのピクセルの値は0から255の間の整数です。**ラベル**(label)は、0から9までの整数の配列です。それぞれの数字が下表のように、衣料品の**クラス**(class)に対応しています。\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
LabelClass
0T-shirt/top
1Trouser
2Pullover
3Dress
4Coat
5Sandal
6Shirt
7Sneaker
8Bag
9Ankle boot
\n", "\n", "画像はそれぞれ単一のラベルに分類されます。データセットには上記の**クラス名**が含まれていないため、後ほど画像を出力するときのために、クラス名を保存しておきます。" ] }, { "metadata": { "colab_type": "code", "id": "IjnLH5S2CaWx", "colab": {} }, "cell_type": "code", "source": [ "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n", " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "Brm0b_KACaWX" }, "cell_type": "markdown", "source": [ "## データの観察\n", "\n", "モデルの訓練を行う前に、データセットのフォーマットを見てみましょう。下記のように、訓練用データセットには28×28ピクセルの画像が60,000枚含まれています。" ] }, { "metadata": { "colab_type": "code", "id": "zW5k_xz1CaWX", "colab": {} }, "cell_type": "code", "source": [ "train_images.shape" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "cIAcvQqMCaWf" }, "cell_type": "markdown", "source": [ "同様に、訓練用データセットには60,000個のラベルが含まれます。" ] }, { "metadata": { "colab_type": "code", "id": "TRFYHB2mCaWb", "colab": {} }, "cell_type": "code", "source": [ "len(train_labels)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "YSlYxFuRCaWk" }, "cell_type": "markdown", "source": [ "ラベルはそれぞれ、0から9までの間の整数です。" ] }, { "metadata": { "colab_type": "code", "id": "XKnCTHz4CaWg", "colab": {} }, "cell_type": "code", "source": [ "train_labels" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "TMPI88iZpO2T" }, "cell_type": "markdown", "source": [ "テスト用データセットには、10,000枚の画像が含まれます。画像は28×28ピクセルで構成されています。" ] }, { "metadata": { "colab_type": "code", "id": "2KFnYlcwCaWl", "colab": {} }, "cell_type": "code", "source": [ "test_images.shape" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "rd0A0Iu0CaWq" }, "cell_type": "markdown", "source": [ "テスト用データセットには10,000個のラベルが含まれます。" ] }, { "metadata": { "colab_type": "code", "id": "iJmPr5-ACaWn", "colab": {} }, "cell_type": "code", "source": [ "len(test_labels)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "ES6uQoLKCaWr" }, "cell_type": "markdown", "source": [ "## データの前処理\n", "\n", "ネットワークを訓練する前に、データを前処理する必要があります。最初の画像を調べてみればわかるように、ピクセルの値は0から255の間の数値です。" ] }, { "metadata": { "colab_type": "code", "id": "m4VEw8Ud9Quh", "colab": {} }, "cell_type": "code", "source": [ "plt.figure()\n", "plt.imshow(train_images[0])\n", "plt.colorbar()\n", "plt.gca().grid(False)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "Wz7l27Lz9S1P" }, "cell_type": "markdown", "source": [ "ニューラルネットワークにデータを投入する前に、これらの値を0から1までの範囲にスケールします。そのためには、画素の値を255で割ります。" ] }, { "metadata": { "colab_type": "text", "id": "3jCZdQNNCaWv" }, "cell_type": "markdown", "source": [ "**訓練用データセット**と**テスト用データセット**は、同じように前処理することが重要です。" ] }, { "metadata": { "colab_type": "code", "id": "bW5WzIPlCaWv", "colab": {} }, "cell_type": "code", "source": [ "train_images = train_images / 255.0\n", "\n", "test_images = test_images / 255.0" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "Ee638AlnCaWz" }, "cell_type": "markdown", "source": [ "**訓練用データセット**の最初の25枚の画像を、クラス名付きで表示してみましょう。ネットワークを構築・訓練する前に、データが正しいフォーマットになっていることを確認します。" ] }, { "metadata": { "colab_type": "code", "id": "oZTImqg_CaW1", "colab": {} }, "cell_type": "code", "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.figure(figsize=(10,10))\n", "for i in range(25):\n", " plt.subplot(5,5,i+1)\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.grid(False)\n", " plt.imshow(train_images[i], cmap=plt.cm.binary)\n", " plt.xlabel(class_names[train_labels[i]])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "59veuiEZCaW4" }, "cell_type": "markdown", "source": [ "## モデルの構築\n", "\n", "ニューラルネットワークを構築するには、まずモデルの階層を定義し、その後モデルをコンパイルします。" ] }, { "metadata": { "colab_type": "text", "id": "Gxg1XGm0eOBy" }, "cell_type": "markdown", "source": [ "### 層の設定\n", "\n", "ニューラルネットワークを形作る基本的な構成要素は**層**(layer)です。層は、入力されたデータから「表現」を抽出します。それらの「表現」は、今取り組もうとしている問題に対して、より「意味のある」ものであることが期待されます。\n", "\n", "ディープラーニングモデルのほとんどは、単純な層の積み重ねで構成されています。`tf.keras.layers.Dense` のような層のほとんどには、訓練中に学習されるパラメータが存在します。" ] }, { "metadata": { "colab_type": "code", "id": "9ODch-OFCaW4", "colab": {} }, "cell_type": "code", "source": [ "model = keras.Sequential([\n", " keras.layers.Flatten(input_shape=(28, 28)),\n", " keras.layers.Dense(128, activation=tf.nn.relu),\n", " keras.layers.Dense(10, activation=tf.nn.softmax)\n", "])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "gut8A_7rCaW6" }, "cell_type": "markdown", "source": [ "このネットワークの最初の層は、`tf.keras.layers.Flatten` です。この層は、画像を(28×28ピクセルの)2次元配列から、28×28=784ピクセルの、1次元配列に変換します。この層が、画像の中に積まれているピクセルの行を取り崩し、横に並べると考えてください。この層には学習すべきパラメータはなく、ただデータのフォーマット変換を行うだけです。\n", "\n", "ピクセルが1次元化されたあと、ネットワークは2つの `tf.keras.layers.Dense` 層となります。これらの層は、密結合あるいは全結合されたニューロンの層となります。最初の `Dense` 層には、128個のノード(あるはニューロン)があります。最後の層でもある2番めの層は、10ノードの**softmax**層です。この層は、合計が1になる10個の確率の配列を返します。それぞれのノードは、今見ている画像が10個のクラスのひとつひとつに属する確率を出力します。\n", "\n", "### モデルのコンパイル\n", "\n", "モデルが訓練できるようになるには、いくつかの設定を追加する必要があります。それらの設定は、モデルの**コンパイル**(compile)時に追加されます。\n", "\n", "* **損失関数**(loss function) —訓練中にモデルがどれくらい正確かを測定します。この関数の値を最小化することにより、訓練中のモデルを正しい方向に向かわせようというわけです。\n", "* **オプティマイザ**(optimizer)—モデルが見ているデータと、損失関数の値から、どのようにモデルを更新するかを決定します。\n", "* **メトリクス**(metrics) —訓練とテストのステップを監視するのに使用します。下記の例では*accuracy* (正解率)、つまり、画像が正しく分類された比率を使用しています。" ] }, { "metadata": { "colab_type": "code", "id": "Lhan11blCaW7", "colab": {} }, "cell_type": "code", "source": [ "model.compile(optimizer=tf.keras.optimizers.Adam(),\n", " loss='sparse_categorical_crossentropy',\n", " metrics=['accuracy'])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "qKF6uW-BCaW-" }, "cell_type": "markdown", "source": [ "## モデルの訓練\n", "\n", "ニューラルネットワークの訓練には次のようなステップが必要です。\n", "\n", "1. モデルに訓練用データを投入します—この例では `train_images` と `train_labels` の2つの配列です。\n", "2. モデルは、画像とラベルの対応関係を学習します。\n", "3. モデルにテスト用データセットの予測(分類)を行わせます—この例では `test_images` 配列です。その後、予測結果と `test_labels` 配列を照合します。\n", "\n", "訓練を開始するには、`model.fit` メソッドを呼び出します。モデルを訓練用データに \"fit\"(適合)させるという意味です。" ] }, { "metadata": { "colab_type": "code", "id": "xvwvpA64CaW_", "colab": {} }, "cell_type": "code", "source": [ "model.fit(train_images, train_labels, epochs=5)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "W3ZVOhugCaXA" }, "cell_type": "markdown", "source": [ "モデルの訓練の進行とともに、損失値と正解率が表示されます。このモデルの場合、訓練用データでは0.88(すなわち88%)の正解率に達します。" ] }, { "metadata": { "colab_type": "text", "id": "oEw4bZgGCaXB" }, "cell_type": "markdown", "source": [ "## 正解率の評価\n", "\n", "次に、テスト用データセットに対するモデルの性能を比較します。" ] }, { "metadata": { "colab_type": "code", "id": "VflXLEeECaXC", "colab": {} }, "cell_type": "code", "source": [ "test_loss, test_acc = model.evaluate(test_images, test_labels)\n", "\n", "print('Test accuracy:', test_acc)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "yWfgsmVXCaXG" }, "cell_type": "markdown", "source": [ "ご覧の通り、テスト用データセットでの正解率は、訓練用データセットでの正解率よりも少し低くなります。この訓練時の正解率とテスト時の正解率の差は、**過学習**(over fitting)の一例です。過学習とは、新しいデータに対する機械学習モデルの性能が、訓練時と比較して低下する現象です。" ] }, { "metadata": { "colab_type": "text", "id": "xsoS7CPDCaXH" }, "cell_type": "markdown", "source": [ "## 予測する\n", "\n", "モデルの訓練が終わったら、そのモデルを使って画像の分類予測を行うことが出来ます。" ] }, { "metadata": { "colab_type": "code", "id": "Gl91RPhdCaXI", "colab": {} }, "cell_type": "code", "source": [ "predictions = model.predict(test_images)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "x9Kk1voUCaXJ" }, "cell_type": "markdown", "source": [ "これは、モデルがテスト用データセットの画像のひとつひとつを分類予測した結果です。最初の予測を見てみましょう。" ] }, { "metadata": { "colab_type": "code", "id": "3DmJEUinCaXK", "colab": {} }, "cell_type": "code", "source": [ "predictions[0]" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "-hw1hgeSCaXN" }, "cell_type": "markdown", "source": [ "予測結果は、10個の数字の配列です。これは、その画像が10の衣料品の種類のそれぞれに該当するかの「確信度」を表しています。どのラベルが一番確信度が高いかを見てみましょう。" ] }, { "metadata": { "colab_type": "code", "id": "qsqenuPnCaXO", "colab": {} }, "cell_type": "code", "source": [ "np.argmax(predictions[0])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "E51yS7iCCaXO" }, "cell_type": "markdown", "source": [ "というわけで、このモデルは、この画像が、アンクルブーツ、`class_names[9]` である可能性が最も高いと判断したことになります。これが正しいかどうか、テスト用ラベルを見てみましょう。" ] }, { "metadata": { "colab_type": "code", "id": "Sd7Pgsu6CaXP", "colab": {} }, "cell_type": "code", "source": [ "test_labels[0]" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "kgdvGD52CaXR" }, "cell_type": "markdown", "source": [ "10チャンネルすべてをグラフ化してみることができます。" ] }, { "metadata": { "colab_type": "code", "id": "VsRq6uZiG7eT", "colab": {} }, "cell_type": "code", "source": [ "def plot_image(i, predictions_array, true_label, img):\n", " predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]\n", " plt.grid(False)\n", " plt.xticks([])\n", " plt.yticks([])\n", "\n", " plt.imshow(img, cmap=plt.cm.binary)\n", "\n", " predicted_label = np.argmax(predictions_array)\n", " if predicted_label == true_label:\n", " color = 'blue'\n", " else:\n", " color = 'red'\n", "\n", " plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n", " 100*np.max(predictions_array),\n", " class_names[true_label]),\n", " color=color)\n", "\n", "def plot_value_array(i, predictions_array, true_label):\n", " predictions_array, true_label = predictions_array[i], true_label[i]\n", " plt.grid(False)\n", " plt.xticks([])\n", " plt.yticks([])\n", " thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n", " plt.ylim([0, 1])\n", " predicted_label = np.argmax(predictions_array)\n", "\n", " thisplot[predicted_label].set_color('red')\n", " thisplot[true_label].set_color('blue')" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "aZ_jDyLZG7eW" }, "cell_type": "markdown", "source": [ "0番目の画像と、予測、予測配列を見てみましょう。" ] }, { "metadata": { "colab_type": "code", "id": "UH_jgCxEG7eW", "colab": {} }, "cell_type": "code", "source": [ "i = 0\n", "plt.figure(figsize=(6,3))\n", "plt.subplot(1,2,1)\n", "plot_image(i, predictions, test_labels, test_images)\n", "plt.subplot(1,2,2)\n", "plot_value_array(i, predictions, test_labels)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "code", "id": "5_7K0ZL7G7eY", "colab": {} }, "cell_type": "code", "source": [ "i = 12\n", "plt.figure(figsize=(6,3))\n", "plt.subplot(1,2,1)\n", "plot_image(i, predictions, test_labels, test_images)\n", "plt.subplot(1,2,2)\n", "plot_value_array(i, predictions, test_labels)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "Lduh0pbfG7eb" }, "cell_type": "markdown", "source": [ "予測の中のいくつかの画像を、予測値とともに表示してみましょう。正しい予測は青で、誤っている予測は赤でラベルを表示します。数字は予測したラベルのパーセント(100分率)を示します。自信があるように見えても間違っていることがあることに注意してください。" ] }, { "metadata": { "colab_type": "code", "id": "YGBDAiziCaXR", "colab": {} }, "cell_type": "code", "source": [ "# X個のテスト画像、予測されたラベル、正解ラベルを表示します。\n", "# 正しい予測は青で、間違った予測は赤で表示しています。\n", "num_rows = 5\n", "num_cols = 3\n", "num_images = num_rows*num_cols\n", "plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n", "for i in range(num_images):\n", " plt.subplot(num_rows, 2*num_cols, 2*i+1)\n", " plot_image(i, predictions, test_labels, test_images)\n", " plt.subplot(num_rows, 2*num_cols, 2*i+2)\n", " plot_value_array(i, predictions, test_labels)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "R32zteKHCaXT" }, "cell_type": "markdown", "source": [ "最後に、訓練済みモデルを使って1枚の画像に対する予測を行います。" ] }, { "metadata": { "colab_type": "code", "id": "yRJ7JU7JCaXT", "colab": {} }, "cell_type": "code", "source": [ "# テスト用データセットから画像を1枚取り出す\n", "img = test_images[0]\n", "\n", "print(img.shape)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "vz3bVp21CaXV" }, "cell_type": "markdown", "source": [ "`tf.keras` モデルは、サンプルの中の**バッチ**(batch)あるいは「集まり」について予測を行うように作られています。そのため、1枚の画像を使う場合でも、リスト化する必要があります。" ] }, { "metadata": { "colab_type": "code", "id": "lDFh5yF_CaXW", "colab": {} }, "cell_type": "code", "source": [ "# 画像を1枚だけのバッチのメンバーにする\n", "img = (np.expand_dims(img,0))\n", "\n", "print(img.shape)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "EQ5wLTkcCaXY" }, "cell_type": "markdown", "source": [ "そして、予測を行います。" ] }, { "metadata": { "colab_type": "code", "id": "o_rzNSdrCaXY", "scrolled": true, "colab": {} }, "cell_type": "code", "source": [ "predictions_single = model.predict(img)\n", "\n", "print(predictions_single)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "code", "id": "6o3nwO-KG7ex", "colab": {} }, "cell_type": "code", "source": [ "plot_value_array(0, predictions_single, test_labels)\n", "_ = plt.xticks(range(10), class_names, rotation=45)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "cU1Y2OAMCaXb" }, "cell_type": "markdown", "source": [ "`model.predict` メソッドの戻り値は、リストのリストです。リストの要素のそれぞれが、バッチの中の画像に対応します。バッチの中から、(といってもバッチの中身は1つだけですが)予測を取り出します。" ] }, { "metadata": { "colab_type": "code", "id": "2tRmdq_8CaXb", "colab": {} }, "cell_type": "code", "source": [ "prediction = predictions[0]\n", "\n", "np.argmax(prediction)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "YFc2HbEVCaXd" }, "cell_type": "markdown", "source": [ "というわけで、モデルは9というラベルを予測しました。" ] } ] }