tf.train.Saver
предоставляет методы для сохранения и восстановления моделей. tf.saved_model.simple_save
- это простой способ создания tf.saved_model
, подходящей для этой цели. Оценщики (Estimators) автоматически сохраняют и восстанавливают переменные в model_dir
.
Сохранение и восстановление моделей
Используйте SavedModel
для сохранения и загрузки вашей модели - переменных, графа и метаданных графа. Это независимый от языка, восстанавливаемый, герметичный формат сериализации, позволяющий создавать более высокоуровневые системы и инструменты для производства, потребления и преобразования TensorFlow моделей. TensorFlow предоставляет несколько способов взаимодействия с SavedModel
, включая tf.saved_model
API, tf.estimator.Estimator
и интерфейс командной строки. р>
Создание и загрузка SavedModel
Простое сохранение
Самый простой способ создать SavedModel
- это использовать tf.saved_model.simple_save
функцию:
simple_save(session,
export_dir,
inputs={"x": x, "y": y},
outputs={"z": z})
Данная команда настраивает SavedModel
, таким образом что она может быть загружена посредством TensorFlow serving и поддерживает Predict API. Чтобы получить доступ к API классификации, регрессии или множественного вывода, используйте API ручного создания SavedModel
или tf.estimator.Estimator
.р>
Создание SavedModel вручную
Если ваш вариант использования не охвачен tf.saved_model.simple_save
, используйте ручной
tf.saved_model.builder
для создания SavedModel
.
tf.saved_model.builder.SavedModelBuilder
предоставляет функциональность для сохранения нескольких MetaGraphDef
. MetaGraph - это граф потока данных, плюс связанные с ним переменные, активы (assets) и подписи. MetaGraphDef
является буфером протокола представления MetaGraph. Подпись - это набор вводов и выводов из графа.
Если активы (assets) необходимо сохранить и записать или скопировать на диск, они могут быть предоставлены когда первый MetaGraphDef
добавлен. Если несколько MetaGraphDef
связанны с активом с тем же именем, сохраняется только первая версия.
Каждый MetaGraphDef
, добавленный в SavedModel, должен быть аннотирован с пользовательскими тегами. Теги предоставляют средства для идентификации конкретных MetaGraphDef
для загрузки и восстановления вместе с общим набором переменных и активов (assets). Эти теги обычно комментируют MetaGraphDef
с его функциональностью (например, обслуживание или обучение), и, возможно, с аппаратно-специфичными аспектами (например, GPU).
Например, следующий код предлагает типичный способ использования SavedModelBuilder
для создания SavedModel:
export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph_and_variables(sess,
[tag_constants.TRAINING],
signature_def_map=foo_signatures,
assets_collection=foo_assets,
strip_default_attrs=True)
...
# Добавляем второй MetaGraphDef для вывода.
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph([tag_constants.SERVING],
strip_default_attrs=True)
...
builder.save()
Прямая совместимость с помощью strip_default_attrs = True
Следование приведенным ниже инструкциям обеспечивает прямую совместимость, только если набор операций не изменился.
tf.saved_model.builder.SavedModelBuilder
позволяет пользователям контролировать, должны ли атрибуты по умолчанию быть удалены из NodeDefs
при добавлении мета-графа в комплект SavedModel. И tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables
, и tf.saved_model.builder.SavedModelBuilder.add_meta_graph
методы принимают логический флаг strip_default_attrs
, который управляет этим поведением.
Если strip_default_attrs
имеет значение False
, экспортированный tf.MetaGraphDef
будет иметь атрибуты по умолчанию во всех его tf.NodeDef
экземплярах. Это может нарушить совместимость с последовательностью событий, такой как следующая:
- Существующая операция (Op) (
Foo
) обновлена, чтобы включить новый атрибут (T
) со значением по умолчанию (bool
) в версии 101. - Производитель модели, такой как "trainer binary" забирает это изменение (версия 101) в
OpDef
и реэкспортирует существующую модель, которая использует OpFoo
. - Потребитель модели (например, Tensorflow Serving), использующий более старую версию двоичного файла (версия 100) не имеет атрибута
T
для OpFoo
, но пытается импортировать эту модель. Потребитель модели не распознает атрибутT
вNodeDef
, который использует OpFoo
и поэтому не может загрузить модель. - Если для параметра
strip_default_attrs
установлено значение True, производители моделей могут отказаться от любых аттрибутов со значениями по умолчанию вNodeDefs
. Это помогает гарантировать, что недавно добавленные атрибуты со значениями по умолчанию не приводят к сбою попыток пользователей старых версий моделей загрузить модели, восстановленные с использованием новых тренировочных бинарных файлов.
Загрузка SavedModel в Python
Python-версия SavedModel tf.saved_model.loader
предоставляет возможность загрузки и восстановления для SavedModel. Операция load
требует следующей информации:
- Сессия, в которой нужно восстановить определение графа и переменные.
- Теги, используемые для идентификации загружаемого MetaGraphDef.
- Местоположение (каталог) SavedModel.
При загрузке подмножество переменных, активов (assets) и подписей, предоставляемых как часть определенного MetaGraphDef будет восстановлено в предоставленной сессии.
export_dir = ...
...
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
...
Загрузка SavedModel в C++
C++ версия SavedModel загрузчика предоставляет API для загрузки SavedModel из файла указанного пути, одновременно позволяя SessionOptions
и RunOptions
. Вы должны указать теги, связанные с загружаемым графом. Загруженная версия SavedModel называется SavedModelBundle
и содержит MetaGraphDef и сессию, в которой он загружен.
const string export_dir = ...
SavedModelBundle bundle;
...
LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagTrain},
&bundle);
Загрузка и обслуживание SavedModel в TensorFlow serving
Вы можете легко загружать и обслуживать SavedModel с помощью TensorFlow Serving Model Server бинарного файла. Инструкции о том, как установить сервер или собрать его, если хотите.
Как только у вас будет Model Server, запустите его с помощью:
tensorflow_model_server --port=port-numbers \
--model_name=your-model-name --model_base_path=your_model_base_path
Установите флаги port и model_name на значения по вашему выбору. Ожидается, что флаг model_base_path будет находиться в базовом каталоге с каждой версией вашей модели, находящейся в подкаталоге с цифровым именем. Если у вас есть только единственная версия вашей модели, просто поместите ее в подкаталог, например:
- Поместите модель в
/tmp/model/0001
- Установите для model_base_path значение
/tmp/model
Храните разные версии вашей модели в подкаталогах с числовыми именами общего базового каталог. Например, предположим, что базовым каталогом является /tmp/model
. Если у вас есть только одна версия вашей модели, сохраните ее в /tmp/model/0001
. Если у вас есть две версии вашей модели, сохраните вторую версию в /tmp/model/0002
и так далее. Установите флаг --model-base_path
равной базовому каталогу (/tmp/model
, в этом примере). TensorFlow Model Server будет обслуживать модель в подкаталоге с наибольшим номером этого базового каталога.
Стандартные константы
SavedModel предлагает гибкость для построения и загрузки графов TensorFlow для разных вариантов использования. Для наиболее распространенных случаев использования API-интерфейсы SavedModel предоставляют набор констант в Python и C++, которые легко использовать повторно и совместно в разных инструментах.
Стандартные теги MetaGraphDef
Вы можете использовать наборы тегов, чтобы однозначно идентифицировать MetaGraphDef
, сохраненный в SavedModel. Подмножество часто используемых тегов указывается в:
Стандартные константы SignatureDef
SignatureDef - это буфер протокола, который определяет сигнатуру вычисления, поддерживаемую графом. Обычно используемые ключи ввода, ключи вывода и имена методов определены в:
Использование SavedModel с оценщиками (Estimators)
После обучения Estimator
модели вы можете создать службу из этой модели, которая принимает запросы и возвращает результат. Вы можете запустить такую службу локально на вашем компьютере или развернуть ее в облаке.
Чтобы подготовить Estimator для обслуживания (serving), вы должны экспортировать его в стандартный SavedModel формат. Далее объясняется, как:
- Указывать выходные узлы и соответствующие API, которые могут быть обслуживаемы (Classify, Regress, или Predict).
- Экспортировать свою модель в формат SavedModel.
- Обслуживать модель с локального сервера и запрашивать прогнозы.
Подготовка вводов для обслуживания
Во время обучения input_fn()
принимает данные и готовит их для использования моделью. Во время обслуживания, аналогично, serve_input_receiver_fn()
принимает запросы на вывод и подготавливает их для модели. Эта функция имеет следующие цели:
- Чтобы добавлять заполнители в граф, который обслуживающая система будет наполнять запросами на вывод.
- Чтобы добавлять дополнительные операции, необходимые для преобразования данных из формата ввода в тензоры свойств, ожидаемые моделью.
Функция возвращает tf.estimator.export.ServingInputReceiver
объект, который упаковывает заполнители и результирующие тензоры свойств вместе.
Типичным примером является то, что запросы на вывод поступают в виде сериализованных tf.Example
, поэтому serve_input_receiver_fn()
создает один строковый заполнитель для их получения. Тогда serve_input_receiver_fn()
также отвечает за разбор tf.Example
с помощью добавления tf.parse_example
операции в граф.
При написании такого serve_input_receiver_fn()
вы должны пройти синтаксический анализ спецификация tf.parse_example
, чтобы сообщить парсеру, какие имена функций следует ожидать и как сопоставить их с Tensor
. Спецификация разбора принимает словарь из имен элементов и tf.FixedLenFeature
, tf.VarLenFeature
, и tf.SparseFeature
. Обратите внимание, что эта спецификация анализа не должна включать любые метки или весовые столбцы, так как они не будут доступны во время обслуживания - в отличие от спецификации анализа, используемой в input_fn()
во время тренировки.
В комбинации, тогда:
feature_spec = {'foo': tf.FixedLenFeature(...),
'bar': tf.VarLenFeature(...)}
def serving_input_receiver_fn():
"""Получатель ввода, который ожидает сериализованный tf.Example."""
serialized_tf_example = tf.placeholder(dtype=tf.string,
shape=[default_batch_size],
name='input_example_tensor')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features,
receiver_tensors)
tf.estimator.export.build_parsing_serving_input_receiver_fn
утилитная функция предоставляет приемник ввода для общих случаев.
Примечание. При обучении модели для обслуживания с использованием API Predict с локальным сервером, шаг разбора (parsing) не требуется, потому что модель будет получать необработанные данные свойств.
Даже если вам не требуется разбор или другая обработка ввода - то есть, если обслуживающая система будет напрямую передавать свойство Tensors
- вы все равно должны предоставить serve_input_receiver_fn()
, который создает заполнители для свойства Tensors
и передает их. tf.estimator.export.build_raw_serving_input_receiver_fn
утилита предоставляется для этого.
Если эти утилиты не соответствуют вашим потребностям, вы можете написать свою собственную serving_input_receiver_fn()
. Один из случаев, когда это может быть необходимо, это если ваша тренировочная input_fn()
включает некоторую логику предварительной обработки, которая должна быть пересмотрена во время обслуживания. Чтобы снизить риск перекоса при обучении, рекомендуется инкапсулировать такую обработку в функцию, которая затем вызывается из input_fn()
и serve_input_receiver_fn()
.
Обратите внимание, что serve_input_receiver_fn()
также определяет input часть подписи. То есть при написании serve_input_receiver_fn()
, вы должны сообщить парсеру, какие подписи ожидать и как сопоставить их с ожидаемыми входными данными вашей модели. Напротив, часть подписи output определяется моделью.
Указание выводных данных пользовательской модели
При написании пользовательской model_fn
необходимо заполнить элемент export_outputs
из tf.estimator.EstimatorSpec
. Это словарь {name: output}
, описывающий выходные сигнатуры, которые необходимо экспортировать и использовать во время обслуживания.р>
В обычном случае однократного прогноза этот словарь содержит один элемент, и name
не имеет значения. В многоголовой модели каждая голова представлена запись в этом словаре. В этом случае name
является строкой на ваш выбор, которую можно использовать для запроса определенной головы во время обслуживания.
Каждое значение output
должно быть объектом ExportOutput
, таким как tf.estimator.export.ClassificationOutput
, tf.estimator.export.RegressionOutput
или tf.estimator.export.PredictOutput
.
Эти выходные типы напрямую указывают на TensorFlow Serving API, и таким образом определяют, какие типы запросов будут выполнены.
Отметим, что в случае с несколькими заголовками будет генерироваться SignatureDef
для каждого элемента словаря export_outputs
, возвращенного из model_fn, названного с помощью тех же ключей. Эти SignatureDef
отличаются только своими выходами, как предоставлено соответствующей записью ExportOutput
. Входы всегда те, которые предоставляются serve_input_receiver_fn
. Запрос на вывод может указывать голову по имени. Одна голова должна быть названа используя signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
указывающий, какой SignatureDef
будет обслуживаться при запросе вывода не указывает ни одного.
Выполнение экспорта
Чтобы экспортировать натренированный Estimator, вызовите tf.estimator.Estimator.export_savedmodel
с базовым путем экспорта и serve_input_receiver_fn
.
estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn,
strip_default_attrs=True)
Этот метод создает новый граф, сначала вызывая serve_input_receiver_fn()
, чтобы получить свойство Tensor
, а затем вызвать model_fn()
Estimator
для генерации графа модели на основе этих свойств. Он запускает новую Session
и, по умолчанию, восстанавливает самую последнюю контрольную точку в ней. (Другая контрольная точка может быть передана, если это необходимо.) Наконец, он создает каталог экспорта с меткой времени под заданной export_dir_base
(т. е. export_dir_base/<timestamp>
), и записывает в него SavedModel, содержащую один MetaGraphDef
, сохраненный из этой Session.
Примечание. Вы несете ответственность за сбор мусора старого экспорта. В противном случае последовательныt экспортs будет накапливаться в export_dir_base
.
Обслуживание экспортированной модели локально
Для локального развертывания вы можете обслуживать свою модель, используя TensorFlow Serving, проект с открытым исходным кодом, который загружает SavedModel и предоставляет его в качестве gRPC сервиса.
Сначала установите TensorFlow Serving.
Затем соберите и запустите локальный сервер моделей, заменив $export_dir_base
на путь к SavedModel, который вы экспортировали выше:
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
--port=9000 --model_base_path=$export_dir_base
Теперь у вас есть сервер, который прослушивает запросы на вывод через gRPC по порту 9000.
Запрос прогнозов с локального сервера
Сервер отвечает на запросы gRPC в соответствии с PredictionService gRPC API определением сервиса. (Вложенные буферы протокола определены в различные соседних файлах).
Из API определения сервиса, gRPC фреймворк генерирует клиентские библиотеки на разных языках, обеспечивающих удаленный доступ к API. В проекте с использованием инструмента сборки Bazel, эти библиотеки создаются автоматически и предоставляются через такие зависимости (например, с использованием Python):
deps = [
"//tensorflow_serving/apis:classification_proto_py_pb2",
"//tensorflow_serving/apis:regression_proto_py_pb2",
"//tensorflow_serving/apis:predict_proto_py_pb2",
"//tensorflow_serving/apis:prediction_service_proto_py_pb2"
]
Python клиентский код может затем импортировать библиотеки:
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import regression_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
Отметитим: gnution_service_pb2
определяет службу в целом и поэтому всегда требуется. Однако типичному клиенту потребуется только один из classification_pb2
, regression_pb2
или gnast_pb2
, в зависимости от типа выполняемых запросов.
Отправка запроса gRPC выполняется путем сборки буфера протокола, содержащего данные запроса и передающего их в сервисную заглушку. Обратите внимание, как буфер протокола запроса создается пустым, а затем заполняется через API сгенерированного буфера протокола.
from grpc.beta import implementations
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(
channel)
request = classification_pb2.ClassificationRequest()
example = request.input.example_list.examples.add()
example.features.feature['x'].float_list.value.extend(
image[0].astype(float))
result = stub.Classify(request, 10.0) # 10 секунд таймаут
В этом примере возвращаемым результатом является буфер протокола ClassificationResponse
.р>
Примечание. ClassificationRequest
и RegressionRequest
содержат буфер протокола tenorflow.serving.Input
, который, в свою очередь, содержит список буферов протокола tenorflow.Example
. PredictRequest
, напротив, содержит соотвествие имен свойств и значений, закодированные с помощью TensorProto
. Соответственно: при использовании API Classify
и Regress
, TensorFlow Serving передает сериализованные tf.Example
в граф, поэтому ваша serve_input_receiver_fn()
должна включать tf.parse_example()
операцию. Однако при использовании универсального Predict
API, TensorFlow Serving передает необработанные данные свойств в граф, поэтому должен быть выполнен проход через serve_input_receiver_fn()
.
Структура каталога SavedModel
Когда вы сохраняете модель в формате SavedModel, TensorFlow создает каталог SavedModel, состоящий из следующих подкаталогов и файлов:
assets/
assets.extra/
variables/
variables.data-?????-of-?????
variables.index
saved_model.pb|saved_model.pbtxt
, где:
assets
- это подпапка, содержащая вспомогательные (внешние) файлы, такие как словари. Активы (assets) копируются в расположение SavedModel и могут быть прочитаны при загрузке определенногоMetaGraphDef
.assets.extra
- это подпапка, в которую могут входить высокоуровневые библиотеки и пользователи могут добавлять свои собственные активы, которые сосуществуют с моделью, но не загружены в граф. Эта подпапка не управляется библиотеками SavedModel.variable
- это подпапка, содержащая выходные данные изtf.train.Saver
.saved_model.pb
или <код>saved_model.pbtxt- это буфер протокола SavedModel. Он включает определения графа в виде буферов протоколаMetaGraphDef
.
Одна SavedModel может представлять несколько графов. В этом случае все графы в SavedModel совместно используют единый набор контрольных точек (переменных) и активов. Например, следующая диаграмма показывает одну SavedModel, содержащую три элемента MetaGraphDef
, все три из которых используют один и тот же набор контрольных точек и активов (assets):
Каждый граф связан с определенным набором тегов, который позволяет идентифицировать их во время загрузки или восстановления.
Читайте также другие статьи по этой теме в блоге: