TensorFlow переменные - это лучший способ представления общего постоянного состояния манипулируемого вашей программой. tf.train.Saver
конструктор добавляет save
и restore
операции в граф для всех, или указанного списка, переменных в графе. Объект Saver
предоставляет методы для запуска этих операций, указывая пути для файлов контрольных точек, в которые необходимо производить запись или из которых считывать.
Saver
восстанавливает все переменные, уже определенные в вашей модели.
TensorFlow сохраняет переменные в двоичных файлах контрольных точек, которые создают карту соотвествия имен переменных и значений тензора.
Внимание: файлы модели TensorFlow являются кодом. Будьте осторожны с ненадежным кодом.
Сохранение переменных
Создайте Saver
с помощью tf.train.Saver()
для управления всеми переменными в
модели. Например, следующий фрагмент демонстрирует, как вызвать tf.train.Saver.save
для сохранения переменных в файлах контрольных точек:
# Создаем несколько переменных.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Добавляем операцию для инициализации переменных.
init_op = tf.global_variables_initializer()
# Добавляем операции для сохранения и восстановления всех переменных.
saver = tf.train.Saver()
# Затем, запускаем модель, инициализируем переменные,
# выполняем некоторую работу, и сохраняем переменные на диск.
with tf.Session() as sess:
sess.run(init_op)
# Выполняем некоторую работу с моделью.
inc_v1.op.run()
dec_v2.op.run()
# Сохраняем переменные на диск.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
Восстановление переменных
tf.train.Saver
не только сохраняет переменные в файлы контрольных точек, он также восстанавливает переменные. Обратите внимание, что при восстановлении переменных вам не требуется инициализировать их заранее. Например, следующий фрагмент кода демонстрирует как вызвать tf.train.Saver.restore
метод для восстановления переменных из файлов контрольных точек:
tf.reset_default_graph()
# Создаем несколько переменных.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Добавляем операции для сохранения и восстановления всех переменных.
saver = tf.train.Saver()
# Затем, запускаем модель, используем saver
# для восстановления переменных с диска,
# и выполняем некторую работу с моделью.
with tf.Session() as sess:
# Восстанавливаем переменные с диска.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Проверяем значения переменных
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
Примечание: Не существует физического файла с именем /tmp/model.ckpt
. Это префикс для имен файлов, созданных для контрольной точки. Пользователи взаимодействуют только с префиксом вместо физических файлов контрольных точек.
Выбор переменных для сохранения и восстановления
Если вы не передаете аргументы в tf.train.Saver()
, saver обрабатывает все переменные в графе. Каждая переменная сохраняется под именем, которое было передано когда переменная была создана.
Иногда полезно явно указать имена переменных в файлах контрольных точек. Например, вы могли обучить модель с переменной с именем"weights"
, значение которой вы хотите восстановить в переменную с именем "params".
р>
Иногда удобно также сохранить или восстановить подмножество переменных, используемых моделью. Например, вы могли тренировать нейронную сеть с пятью слоями, и теперь вы хотите обучить новую модель с шестью слоями, которая использует существующие веса пяти обученных слоев. Вы можете использовать saver для восстановления весов только первых пяти слоев.
Вы можете легко указать имена и переменные для сохранения или загрузки, передав в tf.train.Saver()
конструктор любое из следующих:
- Список переменных (которые будут храниться под своими именами).
- Python словарь, в котором ключами являются имена для использования, а значениями являются переменные для управления.
Продолжая из примеров сохранения/восстановления, показанных ранее:
tf.reset_default_graph()
# Создаем несколько переменных.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
# Добавляем операции для сохранения и восстановления только v2,
# используя имя "v2"
saver = tf.train.Saver({"v2": v2})
# Используем saver объект после этого.
with tf.Session() as sess:
# Инициализируем v1 ввиду того, что saver не будет инициализировать.
v1.initializer.run()
saver.restore(sess, "/tmp/model.ckpt")
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
Примечания:
Вы можете создать столько объектов
Saver
, сколько хотите, чтобы сохранять и восстанавливать различные подмножества переменных модели. Та же самая переменная может быть перечислена в нескольких объектах saver; ее значение изменяется только тогда, когда методSaver.restore()
запущен.Если вы восстанавливаете только подмножество переменных модели в начале сессии, вы должны запустить операцию инициализации для других переменных.
Чтобы проверить переменные в контрольной точке, вы можете использовать
inspect_checkpoint
библиотеку, в частности, функциюprint_tensors_in_checkpoint_file
.По умолчанию
Saver
использует значениеtf.Variable.name
свойства для каждой переменной. Однако при создании объектаSaver
вы можете при желании выбирать имена переменных в файлах контрольных точек.
Проверка переменных в контрольной точке
Мы можем быстро проверить переменные в контрольной точке с помощью inspect_checkpoint
библиотеки.р>
Продолжая из примеров сохранения/восстановления, показанных ранее:
# импортируем inspect_checkpoint библиотеку
from tensorflow.python.tools import inspect_checkpoint as chkp
# печатаем все тензоры в файле контрольной точки
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt",
tensor_name='',
all_tensors=True)
# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
# печатаем только тензор v1 в файле контрольной точки
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt",
tensor_name='v1',
all_tensors=False)
# tensor_name: v1
# [ 1. 1. 1.]
# печатаем только тензор v2 в файле контрольной точки
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt",
tensor_name='v2',
all_tensors=False)
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
Читайте также другие статьи по этой теме в нашем блоге: