четверг, 23 августа 2018 г.

Регуляризация для упрощения модели: L2 регуляризация

Посмотрим на следующую кривую генерализации, которая показывает потерю для тренировочного и валидационного наборов в зависимости от количества тренировочных итераций.

Потеря на тренировочном и валидационном наборе

На предыдущем графике показана модель, в которой тренировочная потеря постепенно уменьшается, но валидационная потеря в итоге увеличивается. Другими словами, генерализационная кривая показывает, что модель переобучилась данным в тренировочном наборе. Используя бритву Оккама в машинном обучении, возможно мы можем предупредить переобучение штрафуя сложные модели, этот принцип называется регуляризация.

Другими словами, вместо простой цели уменьшить потерю (эмпирическая минимизация риска):

minimize(Loss(Data|Model))

мы будем минизировать потерю+сложность, это структурная минимизация риска:

minimize(Loss(Data|Model) + complexity(Model))

Наш алгоритм тренировочной оптимизации теперь функция двух членов: потеря, которая измеряет насколько хорошо модель справляется с данными, и регуляризация, которая измеряет сложность модели.

Сложность модели можно представлять двумя путями:

  • Сложность модели - функция весов всех свойств модели.
  • Сложность модели - функция общего количества свойств с ненулевыми весами.

Если сложность модели это функция весов, тогда вес свойства с большим абсолютным значением более сложен, чем вес свойства с небольшим абсолютным значением.

Мы можем сделать сложность подсчитываемой, используя формулу L2 регуляризации, которая определяет регуляризацию как сумму квадратов всех весов свойств:

В этой формуле веса близкие к нулю слабо воздействуют на сложность модели, в то время как веса с экстремальными значениями могут иметь огромное влияние.

Например, линейная модель со следующими весами:

Имеет L2 регуляризацию 26.915:

Но w3 (выделенное жирным шрифтом), с значением, возведенным в квадрат, равным 25, составляет почти всю величину сложности. Сумма квадратов других пяти весов добавляет всего 1.915 к L2 регуляризации.