Implement apply gradients for keras optimizer v2.
PiperOrigin-RevId: 219189655
This commit is contained in:
parent
d09c687897
commit
98fb3b9243
@ -471,6 +471,28 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "keras_optimizer_v2_test",
|
||||||
|
srcs = ["keras_optimizer_v2_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":combinations",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/contrib/optimizer_v2:training",
|
||||||
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"//tensorflow/python/eager:test",
|
||||||
|
"//tensorflow/python/estimator:estimator_py",
|
||||||
|
"//tensorflow/python/feature_column",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python:summary",
|
||||||
|
],
|
||||||
|
tags = [
|
||||||
|
"multi_and_single_gpu",
|
||||||
|
"no_pip",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "estimator_training_test",
|
name = "estimator_training_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
|
237
tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
Normal file
237
tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests that show that DistributionStrategy works with canned Estimator."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
|
from tensorflow.contrib.distribute.python import combinations
|
||||||
|
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||||
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.estimator import run_config
|
||||||
|
from tensorflow.python.estimator import training
|
||||||
|
from tensorflow.python.estimator.canned import dnn_linear_combined
|
||||||
|
from tensorflow.python.estimator.canned import prediction_keys
|
||||||
|
from tensorflow.python.estimator.export import export
|
||||||
|
from tensorflow.python.estimator.inputs import numpy_io
|
||||||
|
from tensorflow.python.feature_column import feature_column
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import adam
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.summary.writer import writer_cache
|
||||||
|
|
||||||
|
|
||||||
|
class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._model_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def dataset_input_fn(self, x, y, batch_size):
|
||||||
|
|
||||||
|
def input_fn():
|
||||||
|
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
|
||||||
|
dataset = dataset.repeat(1).batch(batch_size)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
return input_fn
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
mode=['graph'],
|
||||||
|
distribution=[
|
||||||
|
combinations.one_device_strategy,
|
||||||
|
combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
combinations.mirrored_strategy_with_two_gpus
|
||||||
|
],
|
||||||
|
use_train_and_evaluate=[True, False]))
|
||||||
|
def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate):
|
||||||
|
label_dimension = 2
|
||||||
|
input_dimension = label_dimension
|
||||||
|
batch_size = 10
|
||||||
|
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
||||||
|
data = data.reshape(batch_size, label_dimension)
|
||||||
|
train_input_fn = self.dataset_input_fn(
|
||||||
|
x={'x': data},
|
||||||
|
y=data,
|
||||||
|
batch_size=batch_size // len(distribution.worker_devices))
|
||||||
|
eval_input_fn = self.dataset_input_fn(
|
||||||
|
x={'x': data},
|
||||||
|
y=data,
|
||||||
|
batch_size=batch_size // len(distribution.worker_devices))
|
||||||
|
predict_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={'x': data}, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
|
linear_feature_columns = [
|
||||||
|
feature_column.numeric_column('x', shape=(input_dimension,))
|
||||||
|
]
|
||||||
|
dnn_feature_columns = [
|
||||||
|
feature_column.numeric_column('x', shape=(input_dimension,))
|
||||||
|
]
|
||||||
|
feature_columns = linear_feature_columns + dnn_feature_columns
|
||||||
|
session_config = config_pb2.ConfigProto(
|
||||||
|
log_device_placement=True, allow_soft_placement=True)
|
||||||
|
estimator = dnn_linear_combined.DNNLinearCombinedRegressor(
|
||||||
|
linear_feature_columns=linear_feature_columns,
|
||||||
|
dnn_hidden_units=(2, 2),
|
||||||
|
dnn_feature_columns=dnn_feature_columns,
|
||||||
|
label_dimension=label_dimension,
|
||||||
|
model_dir=self._model_dir,
|
||||||
|
dnn_optimizer=adam.Adam(0.001),
|
||||||
|
linear_optimizer=adam.Adam(0.001),
|
||||||
|
config=run_config.RunConfig(
|
||||||
|
train_distribute=distribution,
|
||||||
|
eval_distribute=distribution,
|
||||||
|
session_config=session_config))
|
||||||
|
|
||||||
|
num_steps = 2
|
||||||
|
if use_train_and_evaluate:
|
||||||
|
scores, _ = training.train_and_evaluate(
|
||||||
|
estimator, training.TrainSpec(train_input_fn, max_steps=num_steps),
|
||||||
|
training.EvalSpec(eval_input_fn))
|
||||||
|
else:
|
||||||
|
estimator.train(train_input_fn, steps=num_steps)
|
||||||
|
scores = estimator.evaluate(eval_input_fn)
|
||||||
|
|
||||||
|
self.assertIn('loss', six.iterkeys(scores))
|
||||||
|
|
||||||
|
predictions = np.array([
|
||||||
|
x[prediction_keys.PredictionKeys.PREDICTIONS]
|
||||||
|
for x in estimator.predict(predict_input_fn)
|
||||||
|
])
|
||||||
|
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
|
||||||
|
|
||||||
|
feature_spec = feature_column.make_parse_example_spec(feature_columns)
|
||||||
|
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
|
||||||
|
feature_spec)
|
||||||
|
export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
|
||||||
|
serving_input_receiver_fn)
|
||||||
|
self.assertTrue(gfile.Exists(export_dir))
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if self._model_dir:
|
||||||
|
writer_cache.FileWriterCache.clear()
|
||||||
|
shutil.rmtree(self._model_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class MirroredStrategyOptimizerV2Test(test.TestCase):
|
||||||
|
|
||||||
|
def testKerasOptimizerWithUnequalInput(self):
|
||||||
|
if context.num_gpus() < 1:
|
||||||
|
self.skipTest('Not enough GPUs.')
|
||||||
|
|
||||||
|
def create_fn(device_id):
|
||||||
|
var = variables.Variable(
|
||||||
|
2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM)
|
||||||
|
# grad for cpu is 1, grad for gpu is 2, avg grad is 1.5.
|
||||||
|
loss = (device_id + 1) * var
|
||||||
|
optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
|
||||||
|
train_op = optimizer.minimize(loss, var_list=[var])
|
||||||
|
m = optimizer.get_slot(var, 'm')
|
||||||
|
v = optimizer.get_slot(var, 'v')
|
||||||
|
return (var, m, v, train_op, optimizer.iteration)
|
||||||
|
|
||||||
|
devices = ['/device:GPU:0', '/device:CPU:0']
|
||||||
|
dist = mirrored_strategy.MirroredStrategy(devices)
|
||||||
|
with dist.scope():
|
||||||
|
(var, m, v, op, counter) = dist.call_for_each_replica(
|
||||||
|
create_fn, dist.worker_device_index, run_concurrently=False)
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
var_val = [2.0, 2.0, 2.0]
|
||||||
|
self.assertAllClose(
|
||||||
|
var_val,
|
||||||
|
self.evaluate(
|
||||||
|
[dist.read_var(var),
|
||||||
|
var.get(devices[0]),
|
||||||
|
var.get(devices[1])]))
|
||||||
|
self.assertAllClose([0, 0, 0],
|
||||||
|
self.evaluate([
|
||||||
|
dist.read_var(counter),
|
||||||
|
counter.get(devices[0]),
|
||||||
|
counter.get(devices[1])
|
||||||
|
]))
|
||||||
|
|
||||||
|
train_op = dist.unwrap(op)
|
||||||
|
self.evaluate(train_op)
|
||||||
|
# m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2
|
||||||
|
m_val = [1.2, 1.2, 1.2]
|
||||||
|
# assert slot variables in both replicas are the same.
|
||||||
|
self.assertAllClose(
|
||||||
|
m_val,
|
||||||
|
self.evaluate(
|
||||||
|
[dist.read_var(m),
|
||||||
|
m.get(devices[0]),
|
||||||
|
m.get(devices[1])]))
|
||||||
|
# v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25
|
||||||
|
v_val = [1.8, 1.8, 1.8]
|
||||||
|
self.assertAllClose(
|
||||||
|
v_val,
|
||||||
|
self.evaluate(
|
||||||
|
[dist.read_var(v),
|
||||||
|
v.get(devices[0]),
|
||||||
|
v.get(devices[1])]))
|
||||||
|
# var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1)
|
||||||
|
# = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8
|
||||||
|
var_val = [1.99, 1.99, 1.99]
|
||||||
|
self.assertAllClose(
|
||||||
|
var_val,
|
||||||
|
self.evaluate(
|
||||||
|
[dist.read_var(var),
|
||||||
|
var.get(devices[0]),
|
||||||
|
var.get(devices[1])]))
|
||||||
|
self.assertAllClose([1, 1, 1],
|
||||||
|
self.evaluate([
|
||||||
|
dist.read_var(counter),
|
||||||
|
counter.get(devices[0]),
|
||||||
|
counter.get(devices[1])
|
||||||
|
]))
|
||||||
|
|
||||||
|
self.evaluate(train_op)
|
||||||
|
# m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5
|
||||||
|
m_val = [1.44, 1.44, 1.44]
|
||||||
|
self.assertAllClose(
|
||||||
|
m_val,
|
||||||
|
self.evaluate(
|
||||||
|
[dist.read_var(m),
|
||||||
|
m.get(devices[0]),
|
||||||
|
m.get(devices[1])]))
|
||||||
|
# v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25
|
||||||
|
v_val = [2.16, 2.16, 2.16]
|
||||||
|
self.assertAllClose(
|
||||||
|
v_val,
|
||||||
|
self.evaluate(
|
||||||
|
[dist.read_var(v),
|
||||||
|
v.get(devices[0]),
|
||||||
|
v.get(devices[1])]))
|
||||||
|
self.assertAllClose([2, 2, 2],
|
||||||
|
self.evaluate([
|
||||||
|
dist.read_var(counter),
|
||||||
|
counter.get(devices[0]),
|
||||||
|
counter.get(devices[1])
|
||||||
|
]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
|||||||
py_library(
|
py_library(
|
||||||
name = "optimizer_v2",
|
name = "optimizer_v2",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"adam.py",
|
||||||
"gradient_descent.py",
|
"gradient_descent.py",
|
||||||
"optimizer_v2.py",
|
"optimizer_v2.py",
|
||||||
],
|
],
|
||||||
|
124
tensorflow/python/keras/optimizer_v2/adam.py
Normal file
124
tensorflow/python/keras/optimizer_v2/adam.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Adam for TensorFlow."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.training import training_ops
|
||||||
|
|
||||||
|
|
||||||
|
class Adam(optimizer_v2.OptimizerV2):
|
||||||
|
"""Optimizer that implements the Adam algorithm.
|
||||||
|
|
||||||
|
Adam optimization is a stochastic gradient descent method that is based on
|
||||||
|
adaptive estimation of first-order and second-order moments. According to the
|
||||||
|
reference, the method is 'computationally efficient, has little memory
|
||||||
|
requirement, invariant to diagonal rescaling of gradients, and is well suited
|
||||||
|
for problems that are large in terms of data/parameters'.
|
||||||
|
|
||||||
|
# References
|
||||||
|
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||||
|
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
learning_rate=0.001,
|
||||||
|
beta_1=0.9,
|
||||||
|
beta_2=0.999,
|
||||||
|
epsilon=1e-8,
|
||||||
|
name='Adam'):
|
||||||
|
r"""Construct a new Adam optimizer.
|
||||||
|
|
||||||
|
Initialization:
|
||||||
|
|
||||||
|
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
|
||||||
|
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
|
||||||
|
$$t := 0 \text{(Initialize timestep)}$$
|
||||||
|
|
||||||
|
The update rule for `variable` with gradient `g` uses an optimization
|
||||||
|
described at the end of section2 of the paper:
|
||||||
|
|
||||||
|
$$t := t + 1$$
|
||||||
|
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||||
|
|
||||||
|
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||||
|
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||||
|
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||||
|
|
||||||
|
The default value of 1e-8 for epsilon might not be a good default in
|
||||||
|
general. For example, when training an Inception network on ImageNet a
|
||||||
|
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||||
|
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||||
|
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||||
|
hat" in the paper.
|
||||||
|
|
||||||
|
The sparse implementation of this algorithm (used when the gradient is an
|
||||||
|
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||||
|
lookup in the forward pass) does apply momentum to variable slices even if
|
||||||
|
they were not used in the forward pass (meaning they have a gradient equal
|
||||||
|
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||||
|
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||||
|
behavior (in contrast to some momentum implementations which ignore momentum
|
||||||
|
unless a variable slice was actually used).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||||
|
beta_1: A float value or a constant float tensor. The exponential decay
|
||||||
|
rate for the 1st moment estimates.
|
||||||
|
beta_2: A float value or a constant float tensor. The exponential decay
|
||||||
|
rate for the 2nd moment estimates.
|
||||||
|
epsilon: A small constant for numerical stability. This epsilon is
|
||||||
|
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||||
|
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||||
|
name: Optional name for the operations created when applying gradients.
|
||||||
|
Defaults to "Adam". @compatibility(eager) When eager execution is
|
||||||
|
enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be
|
||||||
|
a callable that takes no arguments and returns the actual value to use.
|
||||||
|
This can be useful for changing these values across different
|
||||||
|
invocations of optimizer functions. @end_compatibility
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(Adam, self).__init__(name)
|
||||||
|
self._lr = learning_rate
|
||||||
|
self._beta_1 = beta_1
|
||||||
|
self._beta_2 = beta_2
|
||||||
|
self._epsilon = epsilon
|
||||||
|
|
||||||
|
def _create_slots(self, var_list):
|
||||||
|
# Create slots for the first and second moments.
|
||||||
|
for var in var_list:
|
||||||
|
self.add_slot(var, 'm')
|
||||||
|
self.add_slot(var, 'v')
|
||||||
|
|
||||||
|
def _resource_apply_dense(self, grad, var):
|
||||||
|
m = self.get_slot(var, 'm')
|
||||||
|
v = self.get_slot(var, 'v')
|
||||||
|
# TODO(tanzheny): let optimizer have its own step counter, and let
|
||||||
|
# beta1_power and beta2_power depend on it.
|
||||||
|
return training_ops.resource_apply_adam(
|
||||||
|
var.handle,
|
||||||
|
m.handle,
|
||||||
|
v.handle,
|
||||||
|
math_ops.cast(self._beta_1, grad.dtype.base_dtype),
|
||||||
|
math_ops.cast(self._beta_2, grad.dtype.base_dtype),
|
||||||
|
math_ops.cast(self._lr, grad.dtype.base_dtype),
|
||||||
|
math_ops.cast(self._beta_1, grad.dtype.base_dtype),
|
||||||
|
math_ops.cast(self._beta_2, grad.dtype.base_dtype),
|
||||||
|
math_ops.cast(self._epsilon, grad.dtype.base_dtype),
|
||||||
|
grad,
|
||||||
|
use_locking=self._use_locking)
|
@ -64,13 +64,13 @@ class GradientDescentOptimizerTest(test.TestCase):
|
|||||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||||
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
|
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
|
||||||
sgd_op = gradient_descent.SGD(3.0).apply_gradients(
|
sgd = gradient_descent.SGD(3.0)
|
||||||
zip([grads0, grads1], [var0, var1]))
|
sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
# TODO(apassos) calling initialize_resources on all resources here
|
# TODO(apassos) calling initialize_resources on all resources here
|
||||||
# doesn't work because the sessions and graph are reused across unit
|
# doesn't work because the sessions and graph are reused across unit
|
||||||
# tests and this would mean trying to reinitialize variables. Figure out
|
# tests and this would mean trying to reinitialize variables. Figure out
|
||||||
# a long-term solution for this.
|
# a long-term solution for this.
|
||||||
resources.initialize_resources([var0, var1]).run()
|
resources.initialize_resources([var0, var1, sgd.iteration]).run()
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
|
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
|
||||||
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
|
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
|
||||||
@ -90,13 +90,13 @@ class GradientDescentOptimizerTest(test.TestCase):
|
|||||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||||
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
|
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
|
||||||
lr = lambda: 3.0
|
lr = lambda: 3.0
|
||||||
sgd_op = gradient_descent.SGD(lr).apply_gradients(
|
sgd = gradient_descent.SGD(lr)
|
||||||
zip([grads0, grads1], [var0, var1]))
|
sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
# TODO(apassos) calling initialize_resources on all resources here
|
# TODO(apassos) calling initialize_resources on all resources here
|
||||||
# doesn't work because the sessions and graph are reused across unit
|
# doesn't work because the sessions and graph are reused across unit
|
||||||
# tests and this would mean trying to reinitialize variables. Figure out
|
# tests and this would mean trying to reinitialize variables. Figure out
|
||||||
# a long-term solution for this.
|
# a long-term solution for this.
|
||||||
resources.initialize_resources([var0, var1]).run()
|
resources.initialize_resources([var0, var1, sgd.iteration]).run()
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
|
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
|
||||||
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
|
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
|
||||||
@ -116,12 +116,13 @@ class GradientDescentOptimizerTest(test.TestCase):
|
|||||||
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
|
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
|
||||||
pred = math_ops.matmul(var0, x) + var1
|
pred = math_ops.matmul(var0, x) + var1
|
||||||
loss = pred * pred
|
loss = pred * pred
|
||||||
sgd_op = gradient_descent.SGD(1.0).minimize(loss)
|
sgd = gradient_descent.SGD(1.0)
|
||||||
|
sgd_op = sgd.minimize(loss, [var0, var1])
|
||||||
# TODO(apassos) calling initialize_resources on all resources here
|
# TODO(apassos) calling initialize_resources on all resources here
|
||||||
# doesn't work because the sessions and graph are reused across unit
|
# doesn't work because the sessions and graph are reused across unit
|
||||||
# tests and this would mean trying to reinitialize variables. Figure out
|
# tests and this would mean trying to reinitialize variables. Figure out
|
||||||
# a long-term solution for this.
|
# a long-term solution for this.
|
||||||
resources.initialize_resources([var0, var1]).run()
|
resources.initialize_resources([var0, var1, sgd.iteration]).run()
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
|
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
|
||||||
self.assertAllCloseAccordingToType([3.0], var1.eval())
|
self.assertAllCloseAccordingToType([3.0], var1.eval())
|
||||||
@ -143,7 +144,7 @@ class GradientDescentOptimizerTest(test.TestCase):
|
|||||||
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
|
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
|
||||||
pred += var1
|
pred += var1
|
||||||
loss = pred * pred
|
loss = pred * pred
|
||||||
sgd_op = gradient_descent.SGD(1.0).minimize(loss)
|
sgd_op = gradient_descent.SGD(1.0).minimize(loss, [var0, var1])
|
||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
|
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
|
||||||
@ -193,25 +194,24 @@ class GradientDescentOptimizerTest(test.TestCase):
|
|||||||
def testWithGlobalStep(self):
|
def testWithGlobalStep(self):
|
||||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
global_step = variables.Variable(0, trainable=False)
|
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
||||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||||
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
|
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
|
||||||
sgd_op = gradient_descent.SGD(3.0).apply_gradients(
|
sgd = gradient_descent.SGD(3.0)
|
||||||
zip([grads0, grads1], [var0, var1]), global_step=global_step)
|
sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
|
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
|
||||||
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
|
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
|
||||||
# Run 1 step of sgd
|
# Run 1 step of sgd
|
||||||
sgd_op.run()
|
sgd_op.run()
|
||||||
# Validate updated params and global_step
|
# Validate updated params and optimizer iterations.
|
||||||
self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
|
self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
|
||||||
var0.eval())
|
var0.eval())
|
||||||
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
|
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
|
||||||
var1.eval())
|
var1.eval())
|
||||||
self.assertAllCloseAccordingToType(1, global_step.eval())
|
self.assertAllCloseAccordingToType(1, sgd.iteration.eval())
|
||||||
|
|
||||||
def testSparseBasic(self):
|
def testSparseBasic(self):
|
||||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -22,7 +22,20 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.keras import initializers
|
||||||
|
from tensorflow.python.keras.engine import base_layer
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import gradients
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import distribution_strategy_context
|
||||||
from tensorflow.python.training import optimizer as optimizer_v1
|
from tensorflow.python.training import optimizer as optimizer_v1
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
class OptimizerV2(optimizer_v1.Optimizer):
|
class OptimizerV2(optimizer_v1.Optimizer):
|
||||||
@ -77,29 +90,6 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
opt.apply_gradients(capped_grads_and_vars)
|
opt.apply_gradients(capped_grads_and_vars)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Gating Gradients
|
|
||||||
|
|
||||||
Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
|
|
||||||
argument that controls the degree of parallelism during the application of
|
|
||||||
the gradients.
|
|
||||||
|
|
||||||
The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
|
|
||||||
|
|
||||||
<b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
|
|
||||||
the maximum parallelism in execution, at the cost of some non-reproducibility
|
|
||||||
in the results. For example the two gradients of `matmul` depend on the input
|
|
||||||
values: With `GATE_NONE` one of the gradients could be applied to one of the
|
|
||||||
inputs _before_ the other gradient is computed resulting in non-reproducible
|
|
||||||
results.
|
|
||||||
|
|
||||||
<b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
|
|
||||||
they are used. This prevents race conditions for Ops that generate gradients
|
|
||||||
for multiple inputs where the gradients depend on the inputs.
|
|
||||||
|
|
||||||
<b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
|
|
||||||
before any one of them is used. This provides the least parallelism but can
|
|
||||||
be useful if you want to process all gradients before applying any of them.
|
|
||||||
|
|
||||||
### Slots
|
### Slots
|
||||||
|
|
||||||
Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
|
Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
|
||||||
@ -111,11 +101,6 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
This can be useful if you want to log debug a training algorithm, report stats
|
This can be useful if you want to log debug a training algorithm, report stats
|
||||||
about the slots, etc.
|
about the slots, etc.
|
||||||
|
|
||||||
### Non-slot variables
|
|
||||||
|
|
||||||
Some optimizer subclasses, such as `AdamOptimizer` have variables that
|
|
||||||
are not associated with the variables to train, just the step itself.
|
|
||||||
|
|
||||||
### Hyper parameters
|
### Hyper parameters
|
||||||
|
|
||||||
These are arguments passed to the optimizer subclass constructor
|
These are arguments passed to the optimizer subclass constructor
|
||||||
@ -124,18 +109,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
callables. If they are callable, the callable will be called during
|
callables. If they are callable, the callable will be called during
|
||||||
`apply_gradients()` to get the value for the hyper parameter.
|
`apply_gradients()` to get the value for the hyper parameter.
|
||||||
|
|
||||||
### State
|
|
||||||
|
|
||||||
Internal methods are passed a `state` argument with the correct
|
|
||||||
values to use for the slot and non-slot variables, and the hyper
|
|
||||||
parameters.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Values for gate_gradients.
|
|
||||||
GATE_NONE = 0
|
|
||||||
GATE_OP = 1
|
|
||||||
GATE_GRAPH = 2
|
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
"""Create a new Optimizer.
|
"""Create a new Optimizer.
|
||||||
|
|
||||||
@ -145,6 +120,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
you should be able to use the _set_hyper()/state.get_hyper()
|
you should be able to use the _set_hyper()/state.get_hyper()
|
||||||
facility instead.
|
facility instead.
|
||||||
|
|
||||||
|
This class in stateful and thread-compatible.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: A non-empty string. The name to use for accumulators created
|
name: A non-empty string. The name to use for accumulators created
|
||||||
for the optimizer.
|
for the optimizer.
|
||||||
@ -157,6 +134,192 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
self._use_locking = True
|
self._use_locking = True
|
||||||
super(OptimizerV2, self).__init__(self._use_locking, name)
|
super(OptimizerV2, self).__init__(self._use_locking, name)
|
||||||
self._hyper = {}
|
self._hyper = {}
|
||||||
|
self._slots = {}
|
||||||
|
self._prepared = False
|
||||||
|
|
||||||
|
def minimize(self,
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
aggregation_method=None,
|
||||||
|
colocate_gradients_with_ops=False,
|
||||||
|
name=None,
|
||||||
|
grad_loss=None):
|
||||||
|
"""Add operations to minimize `loss` by updating `var_list`.
|
||||||
|
|
||||||
|
This method simply combines calls `compute_gradients()` and
|
||||||
|
`apply_gradients()`. If you want to process the gradient before applying
|
||||||
|
them call `compute_gradients()` and `apply_gradients()` explicitly instead
|
||||||
|
of using this function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss: A `Tensor` containing the value to minimize.
|
||||||
|
var_list: list or tuple of `Variable` objects to update to minimize
|
||||||
|
`loss`.
|
||||||
|
aggregation_method: Specifies the method used to combine gradient terms.
|
||||||
|
Valid values are defined in the class `AggregationMethod`.
|
||||||
|
colocate_gradients_with_ops: If True, try colocating gradients with the
|
||||||
|
corresponding op.
|
||||||
|
name: Optional name for the returned operation.
|
||||||
|
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An Operation that updates the variables in `var_list`. If `global_step`
|
||||||
|
was not `None`, that operation also increments `global_step`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If some of the variables are not `Variable` objects.
|
||||||
|
|
||||||
|
@compatibility(eager)
|
||||||
|
When eager execution is enabled, `loss` should be a Python function that
|
||||||
|
takes no arguments and computes the value to be minimized. Minimization (and
|
||||||
|
gradient computation) is done with respect to the elements of `var_list` if
|
||||||
|
not None, else with respect to any trainable variables created during the
|
||||||
|
execution of the `loss` function. `gate_gradients`, `aggregation_method`,
|
||||||
|
`colocate_gradients_with_ops` and `grad_loss` are ignored when eager
|
||||||
|
execution is enabled.
|
||||||
|
@end_compatibility
|
||||||
|
"""
|
||||||
|
grads_and_vars = self.compute_gradients(
|
||||||
|
loss,
|
||||||
|
var_list=var_list,
|
||||||
|
aggregation_method=aggregation_method,
|
||||||
|
colocate_gradients_with_ops=colocate_gradients_with_ops,
|
||||||
|
grad_loss=grad_loss)
|
||||||
|
|
||||||
|
return self.apply_gradients(grads_and_vars, name=name)
|
||||||
|
|
||||||
|
def compute_gradients(self,
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
aggregation_method=None,
|
||||||
|
colocate_gradients_with_ops=False,
|
||||||
|
grad_loss=None,
|
||||||
|
stop_gradients=None):
|
||||||
|
"""Compute gradients of `loss` for the variables in `var_list`.
|
||||||
|
|
||||||
|
This is the first part of `minimize()`. It returns a list
|
||||||
|
of (gradient, variable) pairs where "gradient" is the gradient
|
||||||
|
for "variable". Note that "gradient" can be a `Tensor`, an
|
||||||
|
`IndexedSlices`, or `None` if there is no gradient for the
|
||||||
|
given variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss: A Tensor containing the value to minimize or a callable taking no
|
||||||
|
arguments which returns the value to minimize. When eager execution is
|
||||||
|
enabled it must be a callable.
|
||||||
|
var_list: Optional list or tuple of `tf.Variable` to update to minimize
|
||||||
|
`loss`. Defaults to the list of variables collected in the graph under
|
||||||
|
the key `GraphKeys.TRAINABLE_VARIABLES`.
|
||||||
|
aggregation_method: Specifies the method used to combine gradient terms.
|
||||||
|
Valid values are defined in the class `AggregationMethod`.
|
||||||
|
colocate_gradients_with_ops: If True, try colocating gradients with the
|
||||||
|
corresponding op.
|
||||||
|
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
|
||||||
|
stop_gradients: Optional. A Tensor or list of tensors not to differentiate
|
||||||
|
through.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of (gradient, variable) pairs. Variable is always present, but
|
||||||
|
gradient can be `None`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `var_list` contains anything else than `Variable` objects.
|
||||||
|
ValueError: If some arguments are invalid, or var_list is None.
|
||||||
|
RuntimeError: If called with eager execution enabled and `loss` is
|
||||||
|
not callable.
|
||||||
|
|
||||||
|
@compatibility(eager)
|
||||||
|
When eager execution is enabled, `aggregation_method`, and
|
||||||
|
`colocate_gradients_with_ops` are ignored.
|
||||||
|
@end_compatibility
|
||||||
|
"""
|
||||||
|
var_list = nest.flatten(var_list)
|
||||||
|
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
|
||||||
|
if callable(loss):
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
tape.watch(var_list)
|
||||||
|
loss_value = loss()
|
||||||
|
grads = tape.gradient(loss_value, var_list, grad_loss)
|
||||||
|
else:
|
||||||
|
if context.executing_eagerly():
|
||||||
|
raise RuntimeError("`loss` passed to Optimizer.compute_gradients "
|
||||||
|
"should be a function when eager execution is "
|
||||||
|
"enabled.")
|
||||||
|
self._assert_valid_dtypes([loss])
|
||||||
|
if grad_loss is not None:
|
||||||
|
self._assert_valid_dtypes([grad_loss])
|
||||||
|
grads = gradients.gradients(
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
grad_ys=grad_loss,
|
||||||
|
aggregation_method=aggregation_method,
|
||||||
|
colocate_gradients_with_ops=colocate_gradients_with_ops,
|
||||||
|
stop_gradients=stop_gradients)
|
||||||
|
|
||||||
|
grads_and_vars = list(zip(grads, var_list))
|
||||||
|
self._assert_valid_dtypes([
|
||||||
|
v for g, v in grads_and_vars
|
||||||
|
if g is not None and v.dtype != dtypes.resource
|
||||||
|
])
|
||||||
|
|
||||||
|
return grads_and_vars
|
||||||
|
|
||||||
|
def apply_gradients(self, grads_and_vars, name=None):
|
||||||
|
"""Apply gradients to variables.
|
||||||
|
|
||||||
|
This is the second part of `minimize()`. It returns an `Operation` that
|
||||||
|
applies gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grads_and_vars: List of (gradient, variable) pairs as returned by
|
||||||
|
`compute_gradients()`.
|
||||||
|
name: Optional name for the returned operation. Default to the name
|
||||||
|
passed to the `Optimizer` constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An `Operation` that applies the specified gradients. If `global_step`
|
||||||
|
was not None, that operation also increments `global_step`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `grads_and_vars` is malformed.
|
||||||
|
ValueError: If none of the variables have gradients.
|
||||||
|
"""
|
||||||
|
grads_and_vars = _filter_grads(grads_and_vars)
|
||||||
|
var_list = [v for (_, v) in grads_and_vars]
|
||||||
|
if distribution_strategy_context.has_distribution_strategy():
|
||||||
|
reduced_grads = merge_grads(grads_and_vars)
|
||||||
|
grads_and_vars = zip(reduced_grads, var_list)
|
||||||
|
|
||||||
|
with ops.init_scope():
|
||||||
|
self._create_slots(var_list)
|
||||||
|
update_ops = []
|
||||||
|
|
||||||
|
def update_grad_to_var(grad, var):
|
||||||
|
"""Apply gradient to variable."""
|
||||||
|
if isinstance(var, ops.Tensor):
|
||||||
|
raise NotImplementedError("Trying to update a Tensor ", var)
|
||||||
|
if isinstance(grad, ops.IndexedSlices):
|
||||||
|
if var.constraint is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot use a constraint function on a sparse variable.")
|
||||||
|
return self._resource_apply_sparse_duplicate_indices(
|
||||||
|
grad.values, var, grad.indices)
|
||||||
|
update_op = self._resource_apply_dense(grad, var)
|
||||||
|
if var.constraint is not None:
|
||||||
|
with ops.control_dependencies([update_op]):
|
||||||
|
return var.assign(var.constraint(var))
|
||||||
|
else:
|
||||||
|
return update_op
|
||||||
|
|
||||||
|
with ops.name_scope(name, self._name) as name:
|
||||||
|
self._prepare()
|
||||||
|
for grad, var in grads_and_vars:
|
||||||
|
scope_name = "" if in_eager_execution() else "_" + var.op.name
|
||||||
|
with ops.name_scope("update" + scope_name), ops.colocate_with(var):
|
||||||
|
update_ops.append(update_grad_to_var(grad, var))
|
||||||
|
with ops.colocate_with(self._iterations):
|
||||||
|
update_ops.append(self._iterations.assign_add(1))
|
||||||
|
return control_flow_ops.group(*update_ops)
|
||||||
|
|
||||||
def _set_hyper(self, name, value):
|
def _set_hyper(self, name, value):
|
||||||
self._hyper[name] = value
|
self._hyper[name] = value
|
||||||
@ -166,8 +329,33 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
value = self._hyper[name]
|
value = self._hyper[name]
|
||||||
return self._call_if_callable(value)
|
return self._call_if_callable(value)
|
||||||
|
|
||||||
|
def add_slot(self, var, slot_name):
|
||||||
|
slot_key = _get_slot_key_from_var(var, slot_name)
|
||||||
|
if slot_key not in self._slots:
|
||||||
|
self._slots[slot_key] = self.add_weight(
|
||||||
|
name=slot_key, shape=var.shape, dtype=var.dtype)
|
||||||
|
|
||||||
|
def get_slot(self, var, slot_name):
|
||||||
|
slot_key = _get_slot_key_from_var(var, slot_name)
|
||||||
|
return self._slots[slot_key]
|
||||||
|
|
||||||
def _prepare(self):
|
def _prepare(self):
|
||||||
pass
|
if self._prepared:
|
||||||
|
return
|
||||||
|
# This is where all hyper variables will be created.
|
||||||
|
with ops.device("cpu:0"):
|
||||||
|
self._iterations = self.add_weight(
|
||||||
|
self._name + "/iter",
|
||||||
|
shape=[],
|
||||||
|
trainable=False,
|
||||||
|
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
|
||||||
|
self._prepared = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def iteration(self):
|
||||||
|
if not self._prepared:
|
||||||
|
self._prepare()
|
||||||
|
return self._iterations
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
@ -205,3 +393,116 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
def _serialize_hyperparameter(self, hyperparameter_name):
|
def _serialize_hyperparameter(self, hyperparameter_name):
|
||||||
"""Serialize a hyperparameter that can be a float, callable, or Tensor."""
|
"""Serialize a hyperparameter that can be a float, callable, or Tensor."""
|
||||||
return self._hyper[hyperparameter_name]
|
return self._hyper[hyperparameter_name]
|
||||||
|
|
||||||
|
def add_weight(self,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
dtype=None,
|
||||||
|
initializer="zeros",
|
||||||
|
trainable=None,
|
||||||
|
synchronization=variables.VariableSynchronization.AUTO,
|
||||||
|
aggregation=variables.VariableAggregation.NONE):
|
||||||
|
|
||||||
|
if dtype is None:
|
||||||
|
dtype = dtypes.float32
|
||||||
|
initializer = initializers.get(initializer)
|
||||||
|
|
||||||
|
if synchronization == variables.VariableSynchronization.ON_READ:
|
||||||
|
if trainable:
|
||||||
|
raise ValueError(
|
||||||
|
"Synchronization value can be set to "
|
||||||
|
"VariableSynchronization.ON_READ only for non-trainable variables. "
|
||||||
|
"You have specified trainable=True and "
|
||||||
|
"synchronization=VariableSynchronization.ON_READ.")
|
||||||
|
else:
|
||||||
|
# Set trainable to be false when variable is to be synced on read.
|
||||||
|
trainable = False
|
||||||
|
elif trainable is None:
|
||||||
|
trainable = True
|
||||||
|
|
||||||
|
variable = self._add_variable_with_custom_getter(
|
||||||
|
name=name,
|
||||||
|
shape=shape,
|
||||||
|
getter=base_layer.make_variable,
|
||||||
|
overwrite=True,
|
||||||
|
initializer=initializers.get(initializer),
|
||||||
|
dtype=dtype,
|
||||||
|
trainable=trainable,
|
||||||
|
use_resource=True,
|
||||||
|
synchronization=synchronization,
|
||||||
|
aggregation=aggregation)
|
||||||
|
|
||||||
|
return variable
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_grads(grads_and_vars):
|
||||||
|
"""Filter out iterable with grad equal to None."""
|
||||||
|
grads_and_vars = tuple(grads_and_vars)
|
||||||
|
if not grads_and_vars:
|
||||||
|
raise ValueError("No variables provided.")
|
||||||
|
filtered = []
|
||||||
|
vars_with_empty_grads = []
|
||||||
|
for grad, var in grads_and_vars:
|
||||||
|
if grad is None:
|
||||||
|
vars_with_empty_grads.append(var)
|
||||||
|
else:
|
||||||
|
filtered.append((grad, var))
|
||||||
|
filtered = tuple(filtered)
|
||||||
|
if not filtered:
|
||||||
|
raise ValueError("No gradients provided for any variable: %s." %
|
||||||
|
([v.name for _, v in filtered],))
|
||||||
|
if vars_with_empty_grads:
|
||||||
|
logging.warning(
|
||||||
|
("Gradients does not exist for variables %s when minimizing the loss."),
|
||||||
|
([v.name for v in vars_with_empty_grads]))
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def merge_grads(grads_and_vars):
|
||||||
|
"""Merge gradients from different replicas."""
|
||||||
|
|
||||||
|
def merge_grad_fn(strategy, grads_and_vars):
|
||||||
|
reduced_grads = strategy.batch_reduce(
|
||||||
|
variable_scope.VariableAggregation.MEAN, grads_and_vars)
|
||||||
|
return reduced_grads
|
||||||
|
|
||||||
|
return distribution_strategy_context.get_tower_context().merge_call(
|
||||||
|
merge_grad_fn, grads_and_vars)
|
||||||
|
|
||||||
|
|
||||||
|
def in_eager_execution():
|
||||||
|
with ops.init_scope():
|
||||||
|
return context.executing_eagerly()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_slot_key_from_var(var, slot_name):
|
||||||
|
"""Get the slot key for the variable.
|
||||||
|
|
||||||
|
Scope the slot name in the namespace of the primary variable.
|
||||||
|
Set "primary.op.name + '/' + slot_name" as default name.
|
||||||
|
|
||||||
|
In graph mode the name is derived from the op.
|
||||||
|
In eager mode the name is derived from the var.
|
||||||
|
If distribution strategy exists, then the name is derived from the primary
|
||||||
|
variable instead of replica variable, i.e., /dense/kernel instead of
|
||||||
|
/dense/kernel/replica_1. If the slot name is 'm', then the slot variables
|
||||||
|
being created are /dense/kernel/m and /dense/kernel/m/replica_1, instead of
|
||||||
|
/dense/kernel/replica_1/m/replica_1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
var: the variable.
|
||||||
|
slot_name: the name of the slot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the name of the variable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
if distribution_strategy_context.has_distribution_strategy() and hasattr(
|
||||||
|
var, "_primary_var"):
|
||||||
|
var = var._primary_var
|
||||||
|
if context.executing_eagerly():
|
||||||
|
name = var._shared_name
|
||||||
|
else:
|
||||||
|
name = var.op.name
|
||||||
|
return name + "/" + slot_name
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -36,102 +37,89 @@ class OptimizerTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session():
|
||||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||||
def loss():
|
loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
|
||||||
return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
|
if not context.executing_eagerly():
|
||||||
# Note that for eager execution, minimize expects a function instead of a
|
loss = loss()
|
||||||
# Tensor.
|
sgd = gradient_descent.SGD(3.0)
|
||||||
global_step = resource_variable_ops.ResourceVariable(
|
|
||||||
array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
|
||||||
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||||
# Run 1 step of sgd through optimizer
|
# Run 1 step of sgd through optimizer
|
||||||
opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
|
opt_op = sgd.minimize(loss, var_list=[var0, var1])
|
||||||
|
self.evaluate(sgd.iteration.initializer)
|
||||||
self.evaluate(opt_op)
|
self.evaluate(opt_op)
|
||||||
# Validate updated params
|
# Validate updated params
|
||||||
self.assertAllClose([-14., -13.], self.evaluate(var0))
|
self.assertAllClose([-14., -13.], self.evaluate(var0))
|
||||||
self.assertAllClose([-6., -5.], self.evaluate(var1))
|
self.assertAllClose([-6., -5.], self.evaluate(var1))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testAggregationMethod(self):
|
def testAggregationMethod(self):
|
||||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
||||||
cost = 5 * var0 + 3 * var1
|
loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
|
||||||
global_step = variables.Variable(
|
if not context.executing_eagerly():
|
||||||
array_ops.zeros([], dtypes.int64), name='global_step')
|
loss = loss()
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd = gradient_descent.SGD(3.0)
|
||||||
opt_op = sgd_op.minimize(
|
|
||||||
cost,
|
|
||||||
global_step, [var0, var1],
|
|
||||||
aggregation_method=gradients_impl.AggregationMethod.
|
|
||||||
EXPERIMENTAL_ACCUMULATE_N)
|
|
||||||
|
|
||||||
variables.global_variables_initializer().run()
|
self.evaluate(variables.global_variables_initializer())
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||||
# Run 1 step of sgd through optimizer
|
# Run 1 step of sgd through optimizer
|
||||||
opt_op.run()
|
opt_op = sgd.minimize(
|
||||||
|
loss,
|
||||||
|
var_list=[var0, var1],
|
||||||
|
aggregation_method=gradients_impl.AggregationMethod
|
||||||
|
.EXPERIMENTAL_ACCUMULATE_N)
|
||||||
|
self.evaluate(sgd.iteration.initializer)
|
||||||
|
self.evaluate(opt_op)
|
||||||
# Validate updated params
|
# Validate updated params
|
||||||
self.assertAllClose([-14., -13.], var0.eval())
|
self.assertAllClose([-14., -13.], self.evaluate(var0))
|
||||||
self.assertAllClose([-6., -5.], var1.eval())
|
self.assertAllClose([-6., -5.], self.evaluate(var1))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testPrecomputedGradient(self):
|
def testPrecomputedGradient(self):
|
||||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
||||||
cost = 5 * var0 + 3 * var1
|
loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
loss = loss()
|
||||||
grad_loss = constant_op.constant([42, -42], dtype=dtype)
|
grad_loss = constant_op.constant([42, -42], dtype=dtype)
|
||||||
global_step = variables.Variable(
|
sgd = gradient_descent.SGD(3.0)
|
||||||
array_ops.zeros([], dtypes.int64), name='global_step')
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
|
||||||
opt_op = sgd_op.minimize(
|
|
||||||
cost, global_step, [var0, var1], grad_loss=grad_loss)
|
|
||||||
|
|
||||||
variables.global_variables_initializer().run()
|
self.evaluate(variables.global_variables_initializer())
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||||
# Run 1 step of sgd through optimizer
|
# Run 1 step of sgd through optimizer
|
||||||
opt_op.run()
|
opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss)
|
||||||
|
self.evaluate(sgd.iteration.initializer)
|
||||||
|
self.evaluate(opt_op)
|
||||||
# Validate updated params
|
# Validate updated params
|
||||||
self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
|
self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
|
||||||
var0.eval())
|
self.evaluate(var0))
|
||||||
self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
|
self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
|
||||||
var1.eval())
|
self.evaluate(var1))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testNoVariables(self):
|
|
||||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
|
||||||
# pylint: disable=cell-var-from-loop
|
|
||||||
def loss():
|
|
||||||
var0 = resource_variable_ops.ResourceVariable(
|
|
||||||
[1.0, 2.0], dtype=dtype, trainable=False, name='a')
|
|
||||||
var1 = resource_variable_ops.ResourceVariable(
|
|
||||||
[3.0, 4.0], dtype=dtype, trainable=False, name='b')
|
|
||||||
return 5 * var0 + var1
|
|
||||||
# pylint: enable=cell-var-from-loop
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'No.*variables'):
|
|
||||||
sgd_op.minimize(loss)
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testNoGradients(self):
|
def testNoGradients(self):
|
||||||
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session():
|
||||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||||
# pylint: disable=cell-var-from-loop
|
loss = lambda: 5 * var0 # pylint: disable=cell-var-from-loop
|
||||||
def loss():
|
if not context.executing_eagerly():
|
||||||
return 5 * var0
|
loss = loss()
|
||||||
# pylint: enable=cell-var-from-loop
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd_op = gradient_descent.SGD(3.0)
|
||||||
with self.assertRaisesRegexp(ValueError, 'No gradients'):
|
with self.assertRaisesRegexp(ValueError, 'No gradients'):
|
||||||
# var1 has no gradient
|
# var1 has no gradient
|
||||||
@ -140,10 +128,12 @@ class OptimizerTest(test.TestCase):
|
|||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testNoGradientsForAnyVariables_Minimize(self):
|
def testNoGradientsForAnyVariables_Minimize(self):
|
||||||
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session():
|
||||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||||
def loss():
|
loss = lambda: constant_op.constant(5.0)
|
||||||
return constant_op.constant(5.0)
|
if not context.executing_eagerly():
|
||||||
|
loss = loss()
|
||||||
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd_op = gradient_descent.SGD(3.0)
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
@ -153,6 +143,7 @@ class OptimizerTest(test.TestCase):
|
|||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testNoGradientsForAnyVariables_ApplyGradients(self):
|
def testNoGradientsForAnyVariables_ApplyGradients(self):
|
||||||
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session():
|
||||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd_op = gradient_descent.SGD(3.0)
|
||||||
@ -163,17 +154,19 @@ class OptimizerTest(test.TestCase):
|
|||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testGradientsAsVariables(self):
|
def testGradientsAsVariables(self):
|
||||||
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||||
|
with self.cached_session():
|
||||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||||
def loss():
|
loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
|
||||||
return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
|
if not context.executing_eagerly():
|
||||||
|
loss = loss()
|
||||||
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd = gradient_descent.SGD(3.0)
|
||||||
grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
|
grads_and_vars = sgd.compute_gradients(loss, [var0, var1])
|
||||||
# Convert gradients to tf.Variables
|
# Convert gradients to tf.Variables
|
||||||
converted_grads = [
|
converted_grads = [
|
||||||
resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
|
resource_variable_ops.ResourceVariable(
|
||||||
name='c_%d_%d' % (i, j))
|
array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j))
|
||||||
for j, gv in enumerate(grads_and_vars)
|
for j, gv in enumerate(grads_and_vars)
|
||||||
]
|
]
|
||||||
convert_ops = [
|
convert_ops = [
|
||||||
@ -181,8 +174,8 @@ class OptimizerTest(test.TestCase):
|
|||||||
for j, gv in enumerate(grads_and_vars)
|
for j, gv in enumerate(grads_and_vars)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Run convert_ops to achieve the gradients converting
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
# Run convert_ops to achieve the gradietns converting
|
|
||||||
self.evaluate(convert_ops)
|
self.evaluate(convert_ops)
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
@ -190,7 +183,8 @@ class OptimizerTest(test.TestCase):
|
|||||||
|
|
||||||
# Run 1 step of sgd through optimizer
|
# Run 1 step of sgd through optimizer
|
||||||
converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
|
converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
|
||||||
opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
|
opt_op = sgd.apply_gradients(converted_grads_and_vars)
|
||||||
|
self.evaluate(sgd.iteration.initializer)
|
||||||
self.evaluate(opt_op)
|
self.evaluate(opt_op)
|
||||||
|
|
||||||
# Validate updated params
|
# Validate updated params
|
||||||
@ -199,31 +193,23 @@ class OptimizerTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testComputeGradientsWithTensors(self):
|
def testComputeGradientsWithTensors(self):
|
||||||
|
with self.cached_session():
|
||||||
x = ops.convert_to_tensor(1.0)
|
x = ops.convert_to_tensor(1.0)
|
||||||
|
|
||||||
def f():
|
def f():
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd = gradient_descent.SGD(3.0)
|
||||||
grads_and_vars = sgd_op.compute_gradients(f, [x])
|
grads_and_vars = sgd.compute_gradients(f, [x])
|
||||||
self.assertEqual(1, len(grads_and_vars))
|
self.assertEqual(1, len(grads_and_vars))
|
||||||
grad, x_as_var = grads_and_vars[0]
|
grad, x_as_var = grads_and_vars[0]
|
||||||
self.assertIs(x, x_as_var)
|
self.assertIs(x, x_as_var)
|
||||||
self.assertEqual(2.0, self.evaluate(grad))
|
self.assertEqual(2.0, self.evaluate(grad))
|
||||||
|
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
sgd_op.apply_gradients(grads_and_vars)
|
sgd.apply_gradients(grads_and_vars)
|
||||||
|
|
||||||
def testTrainOp(self):
|
|
||||||
with self.cached_session():
|
|
||||||
var0 = variables.Variable([1.0, 2.0])
|
|
||||||
var1 = variables.Variable([3.0, 4.0])
|
|
||||||
cost = 5 * var0 + 3 * var1
|
|
||||||
global_step = variables.Variable(
|
|
||||||
array_ops.zeros([], dtypes.int64), name='global_step')
|
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
|
||||||
opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
|
|
||||||
self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))
|
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testConstraint(self):
|
def testConstraint(self):
|
||||||
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
|
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
|
||||||
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
|
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
|
||||||
@ -232,21 +218,29 @@ class OptimizerTest(test.TestCase):
|
|||||||
constraint=constraint_01)
|
constraint=constraint_01)
|
||||||
var1 = variables.Variable([3.0, 4.0],
|
var1 = variables.Variable([3.0, 4.0],
|
||||||
constraint=constraint_0)
|
constraint=constraint_0)
|
||||||
cost = 5 * var0 + 3 * var1
|
loss = lambda: 5 * var0 + 3 * var1
|
||||||
global_step = variables.Variable(
|
if not context.executing_eagerly(): # pylint: disable=cell-var-from-loop
|
||||||
array_ops.zeros([], dtypes.int64), name='global_step')
|
loss = loss()
|
||||||
sgd_op = gradient_descent.SGD(3.0)
|
sgd = gradient_descent.SGD(3.0)
|
||||||
opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
|
|
||||||
|
|
||||||
variables.global_variables_initializer().run()
|
self.evaluate(variables.global_variables_initializer())
|
||||||
# Fetch params to validate initial values
|
# Fetch params to validate initial values
|
||||||
self.assertAllClose([1.0, 2.0], var0.eval())
|
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||||
self.assertAllClose([3.0, 4.0], var1.eval())
|
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||||
# Run 1 step of sgd through optimizer
|
# Run 1 step of sgd through optimizer
|
||||||
opt_op.run()
|
opt_op = sgd.minimize(loss, var_list=[var0, var1])
|
||||||
|
self.evaluate(sgd.iteration.initializer)
|
||||||
|
self.evaluate(opt_op)
|
||||||
# Validate updated params
|
# Validate updated params
|
||||||
self.assertAllClose([-0.1, -0.1], var0.eval())
|
self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
|
||||||
self.assertAllClose([0., 0.], var1.eval())
|
self.assertAllClose([0., 0.], self.evaluate(var1))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testIterationWithoutMinimize(self):
|
||||||
|
with self.cached_session():
|
||||||
|
sgd = gradient_descent.SGD(3.0)
|
||||||
|
self.evaluate(sgd.iteration.initializer)
|
||||||
|
self.assertEqual(0, self.evaluate(sgd.iteration))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user