пятница, 14 декабря 2018 г.

TensorFlow Core: сохранение и восстановление переменных

TensorFlow переменные - это лучший способ представления общего постоянного состояния манипулируемого вашей программой. tf.train.Saver конструктор добавляет save и restore операции в граф для всех, или указанного списка, переменных в графе. Объект Saver предоставляет методы для запуска этих операций, указывая пути для файлов контрольных точек, в которые необходимо производить запись или из которых считывать.

Saver восстанавливает все переменные, уже определенные в вашей модели.

TensorFlow сохраняет переменные в двоичных файлах контрольных точек, которые создают карту соотвествия имен переменных и значений тензора.

Внимание: файлы модели TensorFlow являются кодом. Будьте осторожны с ненадежным кодом.

Сохранение переменных

Создайте Saver с помощью tf.train.Saver() для управления всеми переменными в модели. Например, следующий фрагмент демонстрирует, как вызвать tf.train.Saver.save для сохранения переменных в файлах контрольных точек:

# Создаем несколько переменных.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Добавляем операцию для инициализации переменных.
init_op = tf.global_variables_initializer()

# Добавляем операции для сохранения и восстановления всех переменных.
saver = tf.train.Saver()

# Затем, запускаем модель, инициализируем переменные, 
# выполняем некоторую работу, и сохраняем переменные на диск.
with tf.Session() as sess:
  sess.run(init_op)
  # Выполняем некоторую работу с моделью.
  inc_v1.op.run()
  dec_v2.op.run()
  # Сохраняем переменные на диск.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Восстановление переменных

tf.train.Saver не только сохраняет переменные в файлы контрольных точек, он также восстанавливает переменные. Обратите внимание, что при восстановлении переменных вам не требуется инициализировать их заранее. Например, следующий фрагмент кода демонстрирует как вызвать tf.train.Saver.restore метод для восстановления переменных из файлов контрольных точек:

tf.reset_default_graph()

# Создаем несколько переменных.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Добавляем операции для сохранения и восстановления всех переменных.
saver = tf.train.Saver()

# Затем, запускаем модель, используем saver 
# для восстановления переменных с диска,
# и выполняем некторую работу с моделью.
with tf.Session() as sess:
  # Восстанавливаем переменные с диска.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Проверяем значения переменных
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Примечание: Не существует физического файла с именем /tmp/model.ckpt. Это префикс для имен файлов, созданных для контрольной точки. Пользователи взаимодействуют только с префиксом вместо физических файлов контрольных точек.

Выбор переменных для сохранения и восстановления

Если вы не передаете аргументы в tf.train.Saver(), saver обрабатывает все переменные в графе. Каждая переменная сохраняется под именем, которое было передано когда переменная была создана.

Иногда полезно явно указать имена переменных в файлах контрольных точек. Например, вы могли обучить модель с переменной с именем"weights", значение которой вы хотите восстановить в переменную с именем "params".

Иногда удобно также сохранить или восстановить подмножество переменных, используемых моделью. Например, вы могли тренировать нейронную сеть с пятью слоями, и теперь вы хотите обучить новую модель с шестью слоями, которая использует существующие веса пяти обученных слоев. Вы можете использовать saver для восстановления весов только первых пяти слоев.

Вы можете легко указать имена и переменные для сохранения или загрузки, передав в tf.train.Saver() конструктор любое из следующих:

  • Список переменных (которые будут храниться под своими именами).
  • Python словарь, в котором ключами являются имена для использования, а значениями являются переменные для управления.

Продолжая из примеров сохранения/восстановления, показанных ранее:

tf.reset_default_graph()
# Создаем несколько переменных.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# Добавляем операции для сохранения и восстановления только v2, 
# используя имя "v2"
saver = tf.train.Saver({"v2": v2})

# Используем saver объект после этого.
with tf.Session() as sess:
  # Инициализируем v1 ввиду того, что saver не будет инициализировать.
  v1.initializer.run()
  saver.restore(sess, "/tmp/model.ckpt")

  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Примечания:

  • Вы можете создать столько объектов Saver, сколько хотите, чтобы сохранять и восстанавливать различные подмножества переменных модели. Та же самая переменная может быть перечислена в нескольких объектах saver; ее значение изменяется только тогда, когда метод Saver.restore() запущен.

  • Если вы восстанавливаете только подмножество переменных модели в начале сессии, вы должны запустить операцию инициализации для других переменных.

  • Чтобы проверить переменные в контрольной точке, вы можете использовать inspect_checkpoint библиотеку, в частности, функцию print_tensors_in_checkpoint_file.

  • По умолчанию Saver использует значение tf.Variable.name свойства для каждой переменной. Однако при создании объекта Saver вы можете при желании выбирать имена переменных в файлах контрольных точек.

Проверка переменных в контрольной точке

Мы можем быстро проверить переменные в контрольной точке с помощью inspect_checkpoint библиотеки.

Продолжая из примеров сохранения/восстановления, показанных ранее:

# импортируем inspect_checkpoint библиотеку
from tensorflow.python.tools import inspect_checkpoint as chkp

# печатаем все тензоры в файле контрольной точки
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", 
                                      tensor_name='', 
                                      all_tensors=True)

# tensor_name:  v1
# [ 1.  1.  1.]
# tensor_name:  v2
# [-1. -1. -1. -1. -1.]

# печатаем только тензор v1 в файле контрольной точки
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", 
                                      tensor_name='v1', 
                                      all_tensors=False)

# tensor_name:  v1
# [ 1.  1.  1.]

# печатаем только тензор v2 в файле контрольной точки
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", 
                                      tensor_name='v2', 
                                      all_tensors=False)

# tensor_name:  v2
# [-1. -1. -1. -1. -1.]


Читайте также другие статьи по этой теме в нашем блоге:

Основы TensorFlow Core

TensorFlow Core: тензоры (tensors)

TensorFlow Core: переменные (variables)