В этом и последующих постах представлено руководство, в котором мы обучим модель нейронной сети классифицировать изображения одежды, такие как кроссовки и майки. Это обзор полной TensorFlow программы с объяснением деталей в ходе изложения.
Это руководство использует tf.keras - высокоуровневое API для построения и тренировки моделей в TensorFlow.
# TensorFlow и tf.keras
import tensorflow as tf
from tensorflow import keras
# Вспомогательные библиотеки
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
1.9.0
Импорт Fashion MNIST набора данных
Это руководство использует Fashion MNIST набор данных, который содержит 70 000 изображений оттенков серого в 10 категориях. Изображения показывают предметы одежды в низком разрешении (28 на 28 пикселей) как на следующей картинке:
Fashion MNIST набор данных
Fashion MNIST предназначен в качестве замены для классического MNIST набора данных - часто используемого для "Hello, World" программ машинного обучения компьютерного зрения. MNIST набор данных содержит изображения рукописных цифр (0, 1, 2, и т.д.) в идентичном формате, что и изображения одежды, которые мы будем использовать в этом руководстве.
Это руководство использует Fashion MNIST за его разнообразие и поэтому это будет немного более сложная задача, чем обычный MNIST. Оба набора данных относительно небольшие и используются для проверки того, что алгоритм работает как предполагалось. Это хорошие стартовые точки для тестирования и отлаживания кода.
Мы будем использовать 60 000 изображений для тренировки сети и 10 000 изображений, чтобы оценить насколько точно сеть обучилась классифицировать изображения. Можно получить доступ к Fashion MNIST напрямую из TensorFlow, просто импортируйте и загрузите данные:
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 5us/step
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 7s 0us/step
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 5s 1us/step
Загрузка набора данных возвращает четыре NumPy массива:
- train_images и train_labels массивы - тренировочный набор - данные, которые модель использует для обучения.
- test_images и test_labels - тестовый набор, на котором модель проходит тестирование.
Изображения - 28x28 NumPy массивы со значениями пикселей в диапазоне от 0 до 255. Метки - массив цифр в диапазоне от 0 до 9. Они соответствуют классам одежды, которые представляют изображения:
Метка | Класс |
---|---|
0 | Футболка/топ (T-shirt/top) |
1 | Брюки (Trousers) |
2 | Пуловер (Pullover) |
3 | Платье (Dress) |
4 | Пальто (Coat) |
5 | Сандалия (Sandal) |
6 | Рубашка (Shirt) |
7 | Кроссовки (Sneaker) |
8 | Сумка (Bag) |
9 | Ботильоны (Ankle boot) |
Каждому изображению соответствует только одна из меток. Ввиду того что названия классов не включены в набор данных, сохраним их для дальнейшего использования при создании графиков.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
Исследуем данные
Исследуем формат набора данных перед тренировкой модели. Присутствует 60 000 изображений в тренировочном наборе, каждое представлено как картинка 28 на 28 пикселей:
train_images.shape
(60000, 28, 28)
Присутствует 60 000 меток в тренировочном наборе:
len(train_labels)
60000
Каждая метка - это цифра между 0 и 9:
train_labels
array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)
Присутствует 10 000 изображений в тестовом наборе. Каждое изображение представлено как 28 на 28 пикселей:
test_images.shape
(10000, 28, 28)
И тестовый набор содержит 10 000 меток изображений:
len(test_labels)
10000
Продолжение руководства читайте в следующем посте.