суббота, 1 сентября 2018 г.

TensorFlow: базовая классификация, часть 1

В этом и последующих постах представлено руководство, в котором мы обучим модель нейронной сети классифицировать изображения одежды, такие как кроссовки и майки. Это обзор полной TensorFlow программы с объяснением деталей в ходе изложения.

Это руководство использует tf.keras - высокоуровневое API для построения и тренировки моделей в TensorFlow.

# TensorFlow и tf.keras
import tensorflow as tf
from tensorflow import keras

# Вспомогательные библиотеки
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

1.9.0

Импорт Fashion MNIST набора данных

Это руководство использует Fashion MNIST набор данных, который содержит 70 000 изображений оттенков серого в 10 категориях. Изображения показывают предметы одежды в низком разрешении (28 на 28 пикселей) как на следующей картинке:

Fashion MNIST набор данных

Fashion MNIST предназначен в качестве замены для классического MNIST набора данных - часто используемого для "Hello, World" программ машинного обучения компьютерного зрения. MNIST набор данных содержит изображения рукописных цифр (0, 1, 2, и т.д.) в идентичном формате, что и изображения одежды, которые мы будем использовать в этом руководстве.

Это руководство использует Fashion MNIST за его разнообразие и поэтому это будет немного более сложная задача, чем обычный MNIST. Оба набора данных относительно небольшие и используются для проверки того, что алгоритм работает как предполагалось. Это хорошие стартовые точки для тестирования и отлаживания кода.

Мы будем использовать 60 000 изображений для тренировки сети и 10 000 изображений, чтобы оценить насколько точно сеть обучилась классифицировать изображения. Можно получить доступ к Fashion MNIST напрямую из TensorFlow, просто импортируйте и загрузите данные:

fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz 32768/29515 [=================================] - 0s 5us/step
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz 26427392/26421880 [==============================] - 7s 0us/step
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz 8192/5148 [===============================================] - 0s 0us/step
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz 4423680/4422102 [==============================] - 5s 1us/step

Загрузка набора данных возвращает четыре NumPy массива:

  • train_images и train_labels массивы - тренировочный набор - данные, которые модель использует для обучения.
  • test_images и test_labels - тестовый набор, на котором модель проходит тестирование.

Изображения - 28x28 NumPy массивы со значениями пикселей в диапазоне от 0 до 255. Метки - массив цифр в диапазоне от 0 до 9. Они соответствуют классам одежды, которые представляют изображения:

МеткаКласс
0Футболка/топ (T-shirt/top)
1Брюки (Trousers)
2Пуловер (Pullover)
3Платье (Dress)
4Пальто (Coat)
5Сандалия (Sandal)
6Рубашка (Shirt)
7Кроссовки (Sneaker)
8Сумка (Bag)
9Ботильоны (Ankle boot)

Каждому изображению соответствует только одна из меток. Ввиду того что названия классов не включены в набор данных, сохраним их для дальнейшего использования при создании графиков.

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

Исследуем данные

Исследуем формат набора данных перед тренировкой модели. Присутствует 60 000 изображений в тренировочном наборе, каждое представлено как картинка 28 на 28 пикселей:

train_images.shape

(60000, 28, 28)

Присутствует 60 000 меток в тренировочном наборе:

len(train_labels)

60000

Каждая метка - это цифра между 0 и 9:

train_labels

array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)

Присутствует 10 000 изображений в тестовом наборе. Каждое изображение представлено как 28 на 28 пикселей:

test_images.shape

(10000, 28, 28)

И тестовый набор содержит 10 000 меток изображений:

len(test_labels)

10000

Продолжение руководства читайте в следующем посте.