diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 1fba094059e..daa7592bcfb 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -471,6 +471,28 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "keras_optimizer_v2_test",
+ srcs = ["keras_optimizer_v2_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
+
cuda_py_test(
name = "estimator_training_test",
size = "large",
diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
new file mode 100644
index 00000000000..f4c222f26c3
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
@@ -0,0 +1,237 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that show that DistributionStrategy works with canned Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+from absl.testing import parameterized
+import numpy as np
+import six
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator import training
+from tensorflow.python.estimator.canned import dnn_linear_combined
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+
+
+class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def dataset_input_fn(self, x, y, batch_size):
+
+ def input_fn():
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(1).batch(batch_size)
+ return dataset
+
+ return input_fn
+
+ @combinations.generate(
+ combinations.combine(
+ mode=['graph'],
+ distribution=[
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus
+ ],
+ use_train_and_evaluate=[True, False]))
+ def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate):
+ label_dimension = 2
+ input_dimension = label_dimension
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ train_input_fn = self.dataset_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size // len(distribution.worker_devices))
+ eval_input_fn = self.dataset_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size // len(distribution.worker_devices))
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ linear_feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+ session_config = config_pb2.ConfigProto(
+ log_device_placement=True, allow_soft_placement=True)
+ estimator = dnn_linear_combined.DNNLinearCombinedRegressor(
+ linear_feature_columns=linear_feature_columns,
+ dnn_hidden_units=(2, 2),
+ dnn_feature_columns=dnn_feature_columns,
+ label_dimension=label_dimension,
+ model_dir=self._model_dir,
+ dnn_optimizer=adam.Adam(0.001),
+ linear_optimizer=adam.Adam(0.001),
+ config=run_config.RunConfig(
+ train_distribute=distribution,
+ eval_distribute=distribution,
+ session_config=session_config))
+
+ num_steps = 2
+ if use_train_and_evaluate:
+ scores, _ = training.train_and_evaluate(
+ estimator, training.TrainSpec(train_input_fn, max_steps=num_steps),
+ training.EvalSpec(eval_input_fn))
+ else:
+ estimator.train(train_input_fn, steps=num_steps)
+ scores = estimator.evaluate(eval_input_fn)
+
+ self.assertIn('loss', six.iterkeys(scores))
+
+ predictions = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in estimator.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, label_dimension), predictions.shape)
+
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+
+class MirroredStrategyOptimizerV2Test(test.TestCase):
+
+ def testKerasOptimizerWithUnequalInput(self):
+ if context.num_gpus() < 1:
+ self.skipTest('Not enough GPUs.')
+
+ def create_fn(device_id):
+ var = variables.Variable(
+ 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM)
+ # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5.
+ loss = (device_id + 1) * var
+ optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
+ train_op = optimizer.minimize(loss, var_list=[var])
+ m = optimizer.get_slot(var, 'm')
+ v = optimizer.get_slot(var, 'v')
+ return (var, m, v, train_op, optimizer.iteration)
+
+ devices = ['/device:GPU:0', '/device:CPU:0']
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ (var, m, v, op, counter) = dist.call_for_each_replica(
+ create_fn, dist.worker_device_index, run_concurrently=False)
+ self.evaluate(variables.global_variables_initializer())
+ var_val = [2.0, 2.0, 2.0]
+ self.assertAllClose(
+ var_val,
+ self.evaluate(
+ [dist.read_var(var),
+ var.get(devices[0]),
+ var.get(devices[1])]))
+ self.assertAllClose([0, 0, 0],
+ self.evaluate([
+ dist.read_var(counter),
+ counter.get(devices[0]),
+ counter.get(devices[1])
+ ]))
+
+ train_op = dist.unwrap(op)
+ self.evaluate(train_op)
+ # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2
+ m_val = [1.2, 1.2, 1.2]
+ # assert slot variables in both replicas are the same.
+ self.assertAllClose(
+ m_val,
+ self.evaluate(
+ [dist.read_var(m),
+ m.get(devices[0]),
+ m.get(devices[1])]))
+ # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25
+ v_val = [1.8, 1.8, 1.8]
+ self.assertAllClose(
+ v_val,
+ self.evaluate(
+ [dist.read_var(v),
+ v.get(devices[0]),
+ v.get(devices[1])]))
+ # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1)
+ # = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8
+ var_val = [1.99, 1.99, 1.99]
+ self.assertAllClose(
+ var_val,
+ self.evaluate(
+ [dist.read_var(var),
+ var.get(devices[0]),
+ var.get(devices[1])]))
+ self.assertAllClose([1, 1, 1],
+ self.evaluate([
+ dist.read_var(counter),
+ counter.get(devices[0]),
+ counter.get(devices[1])
+ ]))
+
+ self.evaluate(train_op)
+ # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5
+ m_val = [1.44, 1.44, 1.44]
+ self.assertAllClose(
+ m_val,
+ self.evaluate(
+ [dist.read_var(m),
+ m.get(devices[0]),
+ m.get(devices[1])]))
+ # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25
+ v_val = [2.16, 2.16, 2.16]
+ self.assertAllClose(
+ v_val,
+ self.evaluate(
+ [dist.read_var(v),
+ v.get(devices[0]),
+ v.get(devices[1])]))
+ self.assertAllClose([2, 2, 2],
+ self.evaluate([
+ dist.read_var(counter),
+ counter.get(devices[0]),
+ counter.get(devices[1])
+ ]))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index e21674ef606..f742e8aa265 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "optimizer_v2",
srcs = [
+ "adam.py",
"gradient_descent.py",
"optimizer_v2.py",
],
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
new file mode 100644
index 00000000000..6c67cd3a61a
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -0,0 +1,124 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adam for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import training_ops
+
+
+class Adam(optimizer_v2.OptimizerV2):
+ """Optimizer that implements the Adam algorithm.
+
+ Adam optimization is a stochastic gradient descent method that is based on
+ adaptive estimation of first-order and second-order moments. According to the
+ reference, the method is 'computationally efficient, has little memory
+ requirement, invariant to diagonal rescaling of gradients, and is well suited
+ for problems that are large in terms of data/parameters'.
+
+ # References
+ See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+ ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-8,
+ name='Adam'):
+ r"""Construct a new Adam optimizer.
+
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
+
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
+
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta_1: A float value or a constant float tensor. The exponential decay
+ rate for the 1st moment estimates.
+ beta_2: A float value or a constant float tensor. The exponential decay
+ rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability. This epsilon is
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
+ Section 2.1), not the epsilon in Algorithm 1 of the paper.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam". @compatibility(eager) When eager execution is
+ enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be
+ a callable that takes no arguments and returns the actual value to use.
+ This can be useful for changing these values across different
+ invocations of optimizer functions. @end_compatibility
+ """
+
+ super(Adam, self).__init__(name)
+ self._lr = learning_rate
+ self._beta_1 = beta_1
+ self._beta_2 = beta_2
+ self._epsilon = epsilon
+
+ def _create_slots(self, var_list):
+ # Create slots for the first and second moments.
+ for var in var_list:
+ self.add_slot(var, 'm')
+ self.add_slot(var, 'v')
+
+ def _resource_apply_dense(self, grad, var):
+ m = self.get_slot(var, 'm')
+ v = self.get_slot(var, 'v')
+ # TODO(tanzheny): let optimizer have its own step counter, and let
+ # beta1_power and beta2_power depend on it.
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(self._beta_1, grad.dtype.base_dtype),
+ math_ops.cast(self._beta_2, grad.dtype.base_dtype),
+ math_ops.cast(self._lr, grad.dtype.base_dtype),
+ math_ops.cast(self._beta_1, grad.dtype.base_dtype),
+ math_ops.cast(self._beta_2, grad.dtype.base_dtype),
+ math_ops.cast(self._epsilon, grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
index 08b273f5562..a1f534d55f6 100644
--- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
+++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
@@ -64,13 +64,13 @@ class GradientDescentOptimizerTest(test.TestCase):
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- sgd_op = gradient_descent.SGD(3.0).apply_gradients(
- zip([grads0, grads1], [var0, var1]))
+ sgd = gradient_descent.SGD(3.0)
+ sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
# TODO(apassos) calling initialize_resources on all resources here
# doesn't work because the sessions and graph are reused across unit
# tests and this would mean trying to reinitialize variables. Figure out
# a long-term solution for this.
- resources.initialize_resources([var0, var1]).run()
+ resources.initialize_resources([var0, var1, sgd.iteration]).run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
@@ -90,13 +90,13 @@ class GradientDescentOptimizerTest(test.TestCase):
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
lr = lambda: 3.0
- sgd_op = gradient_descent.SGD(lr).apply_gradients(
- zip([grads0, grads1], [var0, var1]))
+ sgd = gradient_descent.SGD(lr)
+ sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
# TODO(apassos) calling initialize_resources on all resources here
# doesn't work because the sessions and graph are reused across unit
# tests and this would mean trying to reinitialize variables. Figure out
# a long-term solution for this.
- resources.initialize_resources([var0, var1]).run()
+ resources.initialize_resources([var0, var1, sgd.iteration]).run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
@@ -116,12 +116,13 @@ class GradientDescentOptimizerTest(test.TestCase):
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(var0, x) + var1
loss = pred * pred
- sgd_op = gradient_descent.SGD(1.0).minimize(loss)
+ sgd = gradient_descent.SGD(1.0)
+ sgd_op = sgd.minimize(loss, [var0, var1])
# TODO(apassos) calling initialize_resources on all resources here
# doesn't work because the sessions and graph are reused across unit
# tests and this would mean trying to reinitialize variables. Figure out
# a long-term solution for this.
- resources.initialize_resources([var0, var1]).run()
+ resources.initialize_resources([var0, var1, sgd.iteration]).run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
self.assertAllCloseAccordingToType([3.0], var1.eval())
@@ -143,7 +144,7 @@ class GradientDescentOptimizerTest(test.TestCase):
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
pred += var1
loss = pred * pred
- sgd_op = gradient_descent.SGD(1.0).minimize(loss)
+ sgd_op = gradient_descent.SGD(1.0).minimize(loss, [var0, var1])
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
@@ -193,25 +194,24 @@ class GradientDescentOptimizerTest(test.TestCase):
def testWithGlobalStep(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
- global_step = variables.Variable(0, trainable=False)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- sgd_op = gradient_descent.SGD(3.0).apply_gradients(
- zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ sgd = gradient_descent.SGD(3.0)
+ sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
# Run 1 step of sgd
sgd_op.run()
- # Validate updated params and global_step
+ # Validate updated params and optimizer iterations.
self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
var0.eval())
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
var1.eval())
- self.assertAllCloseAccordingToType(1, global_step.eval())
+ self.assertAllCloseAccordingToType(1, sgd.iteration.eval())
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index fd69cd0c664..c820847e53d 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,20 @@ from __future__ import print_function
import abc
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.util import nest
class OptimizerV2(optimizer_v1.Optimizer):
@@ -77,29 +90,6 @@ class OptimizerV2(optimizer_v1.Optimizer):
opt.apply_gradients(capped_grads_and_vars)
```
- ### Gating Gradients
-
- Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
- argument that controls the degree of parallelism during the application of
- the gradients.
-
- The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
-
- `GATE_NONE`: Compute and apply gradients in parallel. This provides
- the maximum parallelism in execution, at the cost of some non-reproducibility
- in the results. For example the two gradients of `matmul` depend on the input
- values: With `GATE_NONE` one of the gradients could be applied to one of the
- inputs _before_ the other gradient is computed resulting in non-reproducible
- results.
-
- `GATE_OP`: For each Op, make sure all gradients are computed before
- they are used. This prevents race conditions for Ops that generate gradients
- for multiple inputs where the gradients depend on the inputs.
-
- `GATE_GRAPH`: Make sure all gradients for all variables are computed
- before any one of them is used. This provides the least parallelism but can
- be useful if you want to process all gradients before applying any of them.
-
### Slots
Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
@@ -111,11 +101,6 @@ class OptimizerV2(optimizer_v1.Optimizer):
This can be useful if you want to log debug a training algorithm, report stats
about the slots, etc.
- ### Non-slot variables
-
- Some optimizer subclasses, such as `AdamOptimizer` have variables that
- are not associated with the variables to train, just the step itself.
-
### Hyper parameters
These are arguments passed to the optimizer subclass constructor
@@ -124,18 +109,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
callables. If they are callable, the callable will be called during
`apply_gradients()` to get the value for the hyper parameter.
- ### State
-
- Internal methods are passed a `state` argument with the correct
- values to use for the slot and non-slot variables, and the hyper
- parameters.
"""
- # Values for gate_gradients.
- GATE_NONE = 0
- GATE_OP = 1
- GATE_GRAPH = 2
-
def __init__(self, name):
"""Create a new Optimizer.
@@ -145,6 +120,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
you should be able to use the _set_hyper()/state.get_hyper()
facility instead.
+ This class in stateful and thread-compatible.
+
Args:
name: A non-empty string. The name to use for accumulators created
for the optimizer.
@@ -157,6 +134,192 @@ class OptimizerV2(optimizer_v1.Optimizer):
self._use_locking = True
super(OptimizerV2, self).__init__(self._use_locking, name)
self._hyper = {}
+ self._slots = {}
+ self._prepared = False
+
+ def minimize(self,
+ loss,
+ var_list,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Add operations to minimize `loss` by updating `var_list`.
+
+ This method simply combines calls `compute_gradients()` and
+ `apply_gradients()`. If you want to process the gradient before applying
+ them call `compute_gradients()` and `apply_gradients()` explicitly instead
+ of using this function.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ var_list: list or tuple of `Variable` objects to update to minimize
+ `loss`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with the
+ corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
+ Raises:
+ ValueError: If some of the variables are not `Variable` objects.
+
+ @compatibility(eager)
+ When eager execution is enabled, `loss` should be a Python function that
+ takes no arguments and computes the value to be minimized. Minimization (and
+ gradient computation) is done with respect to the elements of `var_list` if
+ not None, else with respect to any trainable variables created during the
+ execution of the `loss` function. `gate_gradients`, `aggregation_method`,
+ `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
+ execution is enabled.
+ @end_compatibility
+ """
+ grads_and_vars = self.compute_gradients(
+ loss,
+ var_list=var_list,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+
+ return self.apply_gradients(grads_and_vars, name=name)
+
+ def compute_gradients(self,
+ loss,
+ var_list,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None,
+ stop_gradients=None):
+ """Compute gradients of `loss` for the variables in `var_list`.
+
+ This is the first part of `minimize()`. It returns a list
+ of (gradient, variable) pairs where "gradient" is the gradient
+ for "variable". Note that "gradient" can be a `Tensor`, an
+ `IndexedSlices`, or `None` if there is no gradient for the
+ given variable.
+
+ Args:
+ loss: A Tensor containing the value to minimize or a callable taking no
+ arguments which returns the value to minimize. When eager execution is
+ enabled it must be a callable.
+ var_list: Optional list or tuple of `tf.Variable` to update to minimize
+ `loss`. Defaults to the list of variables collected in the graph under
+ the key `GraphKeys.TRAINABLE_VARIABLES`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with the
+ corresponding op.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+
+ Returns:
+ A list of (gradient, variable) pairs. Variable is always present, but
+ gradient can be `None`.
+
+ Raises:
+ TypeError: If `var_list` contains anything else than `Variable` objects.
+ ValueError: If some arguments are invalid, or var_list is None.
+ RuntimeError: If called with eager execution enabled and `loss` is
+ not callable.
+
+ @compatibility(eager)
+ When eager execution is enabled, `aggregation_method`, and
+ `colocate_gradients_with_ops` are ignored.
+ @end_compatibility
+ """
+ var_list = nest.flatten(var_list)
+ # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+ if callable(loss):
+ with backprop.GradientTape() as tape:
+ tape.watch(var_list)
+ loss_value = loss()
+ grads = tape.gradient(loss_value, var_list, grad_loss)
+ else:
+ if context.executing_eagerly():
+ raise RuntimeError("`loss` passed to Optimizer.compute_gradients "
+ "should be a function when eager execution is "
+ "enabled.")
+ self._assert_valid_dtypes([loss])
+ if grad_loss is not None:
+ self._assert_valid_dtypes([grad_loss])
+ grads = gradients.gradients(
+ loss,
+ var_list,
+ grad_ys=grad_loss,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ stop_gradients=stop_gradients)
+
+ grads_and_vars = list(zip(grads, var_list))
+ self._assert_valid_dtypes([
+ v for g, v in grads_and_vars
+ if g is not None and v.dtype != dtypes.resource
+ ])
+
+ return grads_and_vars
+
+ def apply_gradients(self, grads_and_vars, name=None):
+ """Apply gradients to variables.
+
+ This is the second part of `minimize()`. It returns an `Operation` that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ name: Optional name for the returned operation. Default to the name
+ passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+
+ Raises:
+ TypeError: If `grads_and_vars` is malformed.
+ ValueError: If none of the variables have gradients.
+ """
+ grads_and_vars = _filter_grads(grads_and_vars)
+ var_list = [v for (_, v) in grads_and_vars]
+ if distribution_strategy_context.has_distribution_strategy():
+ reduced_grads = merge_grads(grads_and_vars)
+ grads_and_vars = zip(reduced_grads, var_list)
+
+ with ops.init_scope():
+ self._create_slots(var_list)
+ update_ops = []
+
+ def update_grad_to_var(grad, var):
+ """Apply gradient to variable."""
+ if isinstance(var, ops.Tensor):
+ raise NotImplementedError("Trying to update a Tensor ", var)
+ if isinstance(grad, ops.IndexedSlices):
+ if var.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ return self._resource_apply_sparse_duplicate_indices(
+ grad.values, var, grad.indices)
+ update_op = self._resource_apply_dense(grad, var)
+ if var.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return var.assign(var.constraint(var))
+ else:
+ return update_op
+
+ with ops.name_scope(name, self._name) as name:
+ self._prepare()
+ for grad, var in grads_and_vars:
+ scope_name = "" if in_eager_execution() else "_" + var.op.name
+ with ops.name_scope("update" + scope_name), ops.colocate_with(var):
+ update_ops.append(update_grad_to_var(grad, var))
+ with ops.colocate_with(self._iterations):
+ update_ops.append(self._iterations.assign_add(1))
+ return control_flow_ops.group(*update_ops)
def _set_hyper(self, name, value):
self._hyper[name] = value
@@ -166,8 +329,33 @@ class OptimizerV2(optimizer_v1.Optimizer):
value = self._hyper[name]
return self._call_if_callable(value)
+ def add_slot(self, var, slot_name):
+ slot_key = _get_slot_key_from_var(var, slot_name)
+ if slot_key not in self._slots:
+ self._slots[slot_key] = self.add_weight(
+ name=slot_key, shape=var.shape, dtype=var.dtype)
+
+ def get_slot(self, var, slot_name):
+ slot_key = _get_slot_key_from_var(var, slot_name)
+ return self._slots[slot_key]
+
def _prepare(self):
- pass
+ if self._prepared:
+ return
+ # This is where all hyper variables will be created.
+ with ops.device("cpu:0"):
+ self._iterations = self.add_weight(
+ self._name + "/iter",
+ shape=[],
+ trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
+ self._prepared = True
+
+ @property
+ def iteration(self):
+ if not self._prepared:
+ self._prepare()
+ return self._iterations
@abc.abstractmethod
def get_config(self):
@@ -205,3 +393,116 @@ class OptimizerV2(optimizer_v1.Optimizer):
def _serialize_hyperparameter(self, hyperparameter_name):
"""Serialize a hyperparameter that can be a float, callable, or Tensor."""
return self._hyper[hyperparameter_name]
+
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer="zeros",
+ trainable=None,
+ synchronization=variables.VariableSynchronization.AUTO,
+ aggregation=variables.VariableAggregation.NONE):
+
+ if dtype is None:
+ dtype = dtypes.float32
+ initializer = initializers.get(initializer)
+
+ if synchronization == variables.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ.")
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+
+ variable = self._add_variable_with_custom_getter(
+ name=name,
+ shape=shape,
+ getter=base_layer.make_variable,
+ overwrite=True,
+ initializer=initializers.get(initializer),
+ dtype=dtype,
+ trainable=trainable,
+ use_resource=True,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ return variable
+
+
+def _filter_grads(grads_and_vars):
+ """Filter out iterable with grad equal to None."""
+ grads_and_vars = tuple(grads_and_vars)
+ if not grads_and_vars:
+ raise ValueError("No variables provided.")
+ filtered = []
+ vars_with_empty_grads = []
+ for grad, var in grads_and_vars:
+ if grad is None:
+ vars_with_empty_grads.append(var)
+ else:
+ filtered.append((grad, var))
+ filtered = tuple(filtered)
+ if not filtered:
+ raise ValueError("No gradients provided for any variable: %s." %
+ ([v.name for _, v in filtered],))
+ if vars_with_empty_grads:
+ logging.warning(
+ ("Gradients does not exist for variables %s when minimizing the loss."),
+ ([v.name for v in vars_with_empty_grads]))
+ return filtered
+
+
+def merge_grads(grads_and_vars):
+ """Merge gradients from different replicas."""
+
+ def merge_grad_fn(strategy, grads_and_vars):
+ reduced_grads = strategy.batch_reduce(
+ variable_scope.VariableAggregation.MEAN, grads_and_vars)
+ return reduced_grads
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_grad_fn, grads_and_vars)
+
+
+def in_eager_execution():
+ with ops.init_scope():
+ return context.executing_eagerly()
+
+
+def _get_slot_key_from_var(var, slot_name):
+ """Get the slot key for the variable.
+
+ Scope the slot name in the namespace of the primary variable.
+ Set "primary.op.name + '/' + slot_name" as default name.
+
+ In graph mode the name is derived from the op.
+ In eager mode the name is derived from the var.
+ If distribution strategy exists, then the name is derived from the primary
+ variable instead of replica variable, i.e., /dense/kernel instead of
+ /dense/kernel/replica_1. If the slot name is 'm', then the slot variables
+ being created are /dense/kernel/m and /dense/kernel/m/replica_1, instead of
+ /dense/kernel/replica_1/m/replica_1.
+
+ Args:
+ var: the variable.
+ slot_name: the name of the slot.
+
+ Returns:
+ the name of the variable.
+ """
+
+ # pylint: disable=protected-access
+ if distribution_strategy_context.has_distribution_strategy() and hasattr(
+ var, "_primary_var"):
+ var = var._primary_var
+ if context.executing_eagerly():
+ name = var._shared_name
+ else:
+ name = var.op.name
+ return name + "/" + slot_name
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index 15f33eb7baa..fe12ab204f1 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -36,194 +37,179 @@ class OptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBasic(self):
- for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- def loss():
- return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
- # Note that for eager execution, minimize expects a function instead of a
- # Tensor.
- global_step = resource_variable_ops.ResourceVariable(
- array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
- sgd_op = gradient_descent.SGD(3.0)
+ for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ if not context.executing_eagerly():
+ loss = loss()
+ sgd = gradient_descent.SGD(3.0)
- self.evaluate(variables.global_variables_initializer())
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], self.evaluate(var0))
- self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- # Run 1 step of sgd through optimizer
- opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
- self.evaluate(opt_op)
- # Validate updated params
- self.assertAllClose([-14., -13.], self.evaluate(var0))
- self.assertAllClose([-6., -5.], self.evaluate(var1))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+ # Run 1 step of sgd through optimizer
+ opt_op = sgd.minimize(loss, var_list=[var0, var1])
+ self.evaluate(sgd.iteration.initializer)
+ self.evaluate(opt_op)
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+ @test_util.run_in_graph_and_eager_modes
def testAggregationMethod(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
- cost = 5 * var0 + 3 * var1
- global_step = variables.Variable(
- array_ops.zeros([], dtypes.int64), name='global_step')
- sgd_op = gradient_descent.SGD(3.0)
- opt_op = sgd_op.minimize(
- cost,
- global_step, [var0, var1],
- aggregation_method=gradients_impl.AggregationMethod.
- EXPERIMENTAL_ACCUMULATE_N)
+ loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ if not context.executing_eagerly():
+ loss = loss()
+ sgd = gradient_descent.SGD(3.0)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 1 step of sgd through optimizer
- opt_op.run()
+ opt_op = sgd.minimize(
+ loss,
+ var_list=[var0, var1],
+ aggregation_method=gradients_impl.AggregationMethod
+ .EXPERIMENTAL_ACCUMULATE_N)
+ self.evaluate(sgd.iteration.initializer)
+ self.evaluate(opt_op)
# Validate updated params
- self.assertAllClose([-14., -13.], var0.eval())
- self.assertAllClose([-6., -5.], var1.eval())
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+ @test_util.run_in_graph_and_eager_modes
def testPrecomputedGradient(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
- cost = 5 * var0 + 3 * var1
+ loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ if not context.executing_eagerly():
+ loss = loss()
grad_loss = constant_op.constant([42, -42], dtype=dtype)
- global_step = variables.Variable(
- array_ops.zeros([], dtypes.int64), name='global_step')
- sgd_op = gradient_descent.SGD(3.0)
- opt_op = sgd_op.minimize(
- cost, global_step, [var0, var1], grad_loss=grad_loss)
+ sgd = gradient_descent.SGD(3.0)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 1 step of sgd through optimizer
- opt_op.run()
+ opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss)
+ self.evaluate(sgd.iteration.initializer)
+ self.evaluate(opt_op)
# Validate updated params
self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
- var0.eval())
+ self.evaluate(var0))
self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
- var1.eval())
-
- @test_util.run_in_graph_and_eager_modes
- def testNoVariables(self):
- for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- # pylint: disable=cell-var-from-loop
- def loss():
- var0 = resource_variable_ops.ResourceVariable(
- [1.0, 2.0], dtype=dtype, trainable=False, name='a')
- var1 = resource_variable_ops.ResourceVariable(
- [3.0, 4.0], dtype=dtype, trainable=False, name='b')
- return 5 * var0 + var1
- # pylint: enable=cell-var-from-loop
- sgd_op = gradient_descent.SGD(3.0)
- with self.assertRaisesRegexp(ValueError, 'No.*variables'):
- sgd_op.minimize(loss)
+ self.evaluate(var1))
@test_util.run_in_graph_and_eager_modes
def testNoGradients(self):
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- # pylint: disable=cell-var-from-loop
- def loss():
- return 5 * var0
- # pylint: enable=cell-var-from-loop
- sgd_op = gradient_descent.SGD(3.0)
- with self.assertRaisesRegexp(ValueError, 'No gradients'):
- # var1 has no gradient
- sgd_op.minimize(loss, var_list=[var1])
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ loss = lambda: 5 * var0 # pylint: disable=cell-var-from-loop
+ if not context.executing_eagerly():
+ loss = loss()
+ sgd_op = gradient_descent.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No gradients'):
+ # var1 has no gradient
+ sgd_op.minimize(loss, var_list=[var1])
@test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_Minimize(self):
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- def loss():
- return constant_op.constant(5.0)
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ loss = lambda: constant_op.constant(5.0)
+ if not context.executing_eagerly():
+ loss = loss()
- sgd_op = gradient_descent.SGD(3.0)
- with self.assertRaisesRegexp(ValueError,
- 'No gradients provided for any variable'):
- sgd_op.minimize(loss, var_list=[var0, var1])
+ sgd_op = gradient_descent.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.minimize(loss, var_list=[var0, var1])
@test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_ApplyGradients(self):
for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- sgd_op = gradient_descent.SGD(3.0)
- with self.assertRaisesRegexp(ValueError,
- 'No gradients provided for any variable'):
- sgd_op.apply_gradients([(None, var0), (None, var1)])
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ sgd_op = gradient_descent.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.apply_gradients([(None, var0), (None, var1)])
@test_util.run_in_graph_and_eager_modes
def testGradientsAsVariables(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- def loss():
- return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ if not context.executing_eagerly():
+ loss = loss()
- sgd_op = gradient_descent.SGD(3.0)
- grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
- # Convert gradients to tf.Variables
- converted_grads = [
- resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
- name='c_%d_%d' % (i, j))
- for j, gv in enumerate(grads_and_vars)
- ]
- convert_ops = [
- state_ops.assign(converted_grads[j], gv[0])
- for j, gv in enumerate(grads_and_vars)
- ]
+ sgd = gradient_descent.SGD(3.0)
+ grads_and_vars = sgd.compute_gradients(loss, [var0, var1])
+ # Convert gradients to tf.Variables
+ converted_grads = [
+ resource_variable_ops.ResourceVariable(
+ array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j))
+ for j, gv in enumerate(grads_and_vars)
+ ]
+ convert_ops = [
+ state_ops.assign(converted_grads[j], gv[0])
+ for j, gv in enumerate(grads_and_vars)
+ ]
- self.evaluate(variables.global_variables_initializer())
- # Run convert_ops to achieve the gradietns converting
- self.evaluate(convert_ops)
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], self.evaluate(var0))
- self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+ # Run convert_ops to achieve the gradients converting
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(convert_ops)
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- # Run 1 step of sgd through optimizer
- converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
- opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
- self.evaluate(opt_op)
+ # Run 1 step of sgd through optimizer
+ converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+ opt_op = sgd.apply_gradients(converted_grads_and_vars)
+ self.evaluate(sgd.iteration.initializer)
+ self.evaluate(opt_op)
- # Validate updated params
- self.assertAllClose([-14., -13.], self.evaluate(var0))
- self.assertAllClose([-6., -5.], self.evaluate(var1))
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
@test_util.run_in_graph_and_eager_modes
def testComputeGradientsWithTensors(self):
- x = ops.convert_to_tensor(1.0)
- def f():
- return x * x
-
- sgd_op = gradient_descent.SGD(3.0)
- grads_and_vars = sgd_op.compute_gradients(f, [x])
- self.assertEqual(1, len(grads_and_vars))
- grad, x_as_var = grads_and_vars[0]
- self.assertIs(x, x_as_var)
- self.assertEqual(2.0, self.evaluate(grad))
-
- with self.assertRaises(NotImplementedError):
- sgd_op.apply_gradients(grads_and_vars)
-
- def testTrainOp(self):
with self.cached_session():
- var0 = variables.Variable([1.0, 2.0])
- var1 = variables.Variable([3.0, 4.0])
- cost = 5 * var0 + 3 * var1
- global_step = variables.Variable(
- array_ops.zeros([], dtypes.int64), name='global_step')
- sgd_op = gradient_descent.SGD(3.0)
- opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
- self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))
+ x = ops.convert_to_tensor(1.0)
+ def f():
+ return x * x
+
+ sgd = gradient_descent.SGD(3.0)
+ grads_and_vars = sgd.compute_gradients(f, [x])
+ self.assertEqual(1, len(grads_and_vars))
+ grad, x_as_var = grads_and_vars[0]
+ self.assertIs(x, x_as_var)
+ self.assertEqual(2.0, self.evaluate(grad))
+
+ with self.assertRaises(NotImplementedError):
+ sgd.apply_gradients(grads_and_vars)
+
+ @test_util.run_in_graph_and_eager_modes
def testConstraint(self):
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
@@ -232,21 +218,29 @@ class OptimizerTest(test.TestCase):
constraint=constraint_01)
var1 = variables.Variable([3.0, 4.0],
constraint=constraint_0)
- cost = 5 * var0 + 3 * var1
- global_step = variables.Variable(
- array_ops.zeros([], dtypes.int64), name='global_step')
- sgd_op = gradient_descent.SGD(3.0)
- opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+ loss = lambda: 5 * var0 + 3 * var1
+ if not context.executing_eagerly(): # pylint: disable=cell-var-from-loop
+ loss = loss()
+ sgd = gradient_descent.SGD(3.0)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 1 step of sgd through optimizer
- opt_op.run()
+ opt_op = sgd.minimize(loss, var_list=[var0, var1])
+ self.evaluate(sgd.iteration.initializer)
+ self.evaluate(opt_op)
# Validate updated params
- self.assertAllClose([-0.1, -0.1], var0.eval())
- self.assertAllClose([0., 0.], var1.eval())
+ self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
+ self.assertAllClose([0., 0.], self.evaluate(var1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testIterationWithoutMinimize(self):
+ with self.cached_session():
+ sgd = gradient_descent.SGD(3.0)
+ self.evaluate(sgd.iteration.initializer)
+ self.assertEqual(0, self.evaluate(sgd.iteration))
if __name__ == '__main__':