From 98fb3b92432dbc68e9170d137fc4aedf721bd1b7 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Mon, 29 Oct 2018 14:06:09 -0700 Subject: [PATCH] Implement apply gradients for keras optimizer v2. PiperOrigin-RevId: 219189655 --- tensorflow/contrib/distribute/python/BUILD | 22 + .../python/keras_optimizer_v2_test.py | 237 +++++++++++ tensorflow/python/keras/optimizer_v2/BUILD | 1 + tensorflow/python/keras/optimizer_v2/adam.py | 124 ++++++ .../optimizer_v2/gradient_descent_test.py | 28 +- .../python/keras/optimizer_v2/optimizer_v2.py | 381 ++++++++++++++++-- .../keras/optimizer_v2/optimizer_v2_test.py | 294 +++++++------- 7 files changed, 883 insertions(+), 204 deletions(-) create mode 100644 tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py create mode 100644 tensorflow/python/keras/optimizer_v2/adam.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 1fba094059e..daa7592bcfb 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -471,6 +471,28 @@ cuda_py_test( ], ) +cuda_py_test( + name = "keras_optimizer_v2_test", + srcs = ["keras_optimizer_v2_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/feature_column", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + cuda_py_test( name = "estimator_training_test", size = "large", diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py new file mode 100644 index 00000000000..f4c222f26c3 --- /dev/null +++ 
b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -0,0 +1,237 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that show that DistributionStrategy works with canned Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import training +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform 
import test +from tensorflow.python.summary.writer import writer_cache + + +class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def dataset_input_fn(self, x, y, batch_size): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(1).batch(batch_size) + return dataset + + return input_fn + + @combinations.generate( + combinations.combine( + mode=['graph'], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + use_train_and_evaluate=[True, False])) + def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + train_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices)) + eval_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices)) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, batch_size=batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + session_config = config_pb2.ConfigProto( + log_device_placement=True, allow_soft_placement=True) + estimator = dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=label_dimension, + model_dir=self._model_dir, + dnn_optimizer=adam.Adam(0.001), + 
linear_optimizer=adam.Adam(0.001), + config=run_config.RunConfig( + train_distribute=distribution, + eval_distribute=distribution, + session_config=session_config)) + + num_steps = 2 + if use_train_and_evaluate: + scores, _ = training.train_and_evaluate( + estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), + training.EvalSpec(eval_input_fn)) + else: + estimator.train(train_input_fn, steps=num_steps) + scores = estimator.evaluate(eval_input_fn) + + self.assertIn('loss', six.iterkeys(scores)) + + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +class MirroredStrategyOptimizerV2Test(test.TestCase): + + def testKerasOptimizerWithUnequalInput(self): + if context.num_gpus() < 1: + self.skipTest('Not enough GPUs.') + + def create_fn(device_id): + var = variables.Variable( + 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) + # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. 
+ loss = (device_id + 1) * var + optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) + train_op = optimizer.minimize(loss, var_list=[var]) + m = optimizer.get_slot(var, 'm') + v = optimizer.get_slot(var, 'v') + return (var, m, v, train_op, optimizer.iteration) + + devices = ['/device:GPU:0', '/device:CPU:0'] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + (var, m, v, op, counter) = dist.call_for_each_replica( + create_fn, dist.worker_device_index, run_concurrently=False) + self.evaluate(variables.global_variables_initializer()) + var_val = [2.0, 2.0, 2.0] + self.assertAllClose( + var_val, + self.evaluate( + [dist.read_var(var), + var.get(devices[0]), + var.get(devices[1])])) + self.assertAllClose([0, 0, 0], + self.evaluate([ + dist.read_var(counter), + counter.get(devices[0]), + counter.get(devices[1]) + ])) + + train_op = dist.unwrap(op) + self.evaluate(train_op) + # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 + m_val = [1.2, 1.2, 1.2] + # assert slot variables in both replicas are the same. 
+ self.assertAllClose( + m_val, + self.evaluate( + [dist.read_var(m), + m.get(devices[0]), + m.get(devices[1])])) + # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 + v_val = [1.8, 1.8, 1.8] + self.assertAllClose( + v_val, + self.evaluate( + [dist.read_var(v), + v.get(devices[0]), + v.get(devices[1])])) + # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) + # = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8 + var_val = [1.99, 1.99, 1.99] + self.assertAllClose( + var_val, + self.evaluate( + [dist.read_var(var), + var.get(devices[0]), + var.get(devices[1])])) + self.assertAllClose([1, 1, 1], + self.evaluate([ + dist.read_var(counter), + counter.get(devices[0]), + counter.get(devices[1]) + ])) + + self.evaluate(train_op) + # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5 + m_val = [1.44, 1.44, 1.44] + self.assertAllClose( + m_val, + self.evaluate( + [dist.read_var(m), + m.get(devices[0]), + m.get(devices[1])])) + # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 + v_val = [2.16, 2.16, 2.16] + self.assertAllClose( + v_val, + self.evaluate( + [dist.read_var(v), + v.get(devices[0]), + v.get(devices[1])])) + self.assertAllClose([2, 2, 2], + self.evaluate([ + dist.read_var(counter), + counter.get(devices[0]), + counter.get(devices[1]) + ])) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index e21674ef606..f742e8aa265 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "optimizer_v2", srcs = [ + "adam.py", "gradient_descent.py", "optimizer_v2.py", ], diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py new file mode 100644 index 00000000000..6c67cd3a61a --- /dev/null +++ 
b/tensorflow/python/keras/optimizer_v2/adam.py @@ -0,0 +1,124 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adam for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from tensorflow.python.ops import math_ops +from tensorflow.python.training import training_ops + + +class Adam(optimizer_v2.OptimizerV2): + """Optimizer that implements the Adam algorithm. + + Adam optimization is a stochastic gradient descent method that is based on + adaptive estimation of first-order and second-order moments. According to the + reference, the method is 'computationally efficient, has little memory + requirement, invariant to diagonal rescaling of gradients, and is well suited + for problems that are large in terms of data/parameters'. + + # References + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + name='Adam'): + r"""Construct a new Adam optimizer. 
+ + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of Section 2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + learning_rate: A Tensor or a floating point value. The learning rate. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability.
This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". @compatibility(eager) When eager execution is + enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be + a callable that takes no arguments and returns the actual value to use. + This can be useful for changing these values across different + invocations of optimizer functions. @end_compatibility + """ + + super(Adam, self).__init__(name) + self._lr = learning_rate + self._beta_1 = beta_1 + self._beta_2 = beta_2 + self._epsilon = epsilon + + def _create_slots(self, var_list): + # Create slots for the first and second moments. + for var in var_list: + self.add_slot(var, 'm') + self.add_slot(var, 'v') + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, 'm') + v = self.get_slot(var, 'v') + # TODO(tanzheny): let optimizer have its own step counter, and let + # beta1_power and beta2_power depend on it. 
+ return training_ops.resource_apply_adam( + var.handle, + m.handle, + v.handle, + math_ops.cast(self._beta_1, grad.dtype.base_dtype), + math_ops.cast(self._beta_2, grad.dtype.base_dtype), + math_ops.cast(self._lr, grad.dtype.base_dtype), + math_ops.cast(self._beta_1, grad.dtype.base_dtype), + math_ops.cast(self._beta_2, grad.dtype.base_dtype), + math_ops.cast(self._epsilon, grad.dtype.base_dtype), + grad, + use_locking=self._use_locking) diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py index 08b273f5562..a1f534d55f6 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py @@ -64,13 +64,13 @@ class GradientDescentOptimizerTest(test.TestCase): var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - sgd_op = gradient_descent.SGD(3.0).apply_gradients( - zip([grads0, grads1], [var0, var1])) + sgd = gradient_descent.SGD(3.0) + sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) # TODO(apassos) calling initialize_resources on all resources here # doesn't work because the sessions and graph are reused across unit # tests and this would mean trying to reinitialize variables. Figure out # a long-term solution for this. 
- resources.initialize_resources([var0, var1]).run() + resources.initialize_resources([var0, var1, sgd.iteration]).run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) @@ -90,13 +90,13 @@ class GradientDescentOptimizerTest(test.TestCase): grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) lr = lambda: 3.0 - sgd_op = gradient_descent.SGD(lr).apply_gradients( - zip([grads0, grads1], [var0, var1])) + sgd = gradient_descent.SGD(lr) + sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) # TODO(apassos) calling initialize_resources on all resources here # doesn't work because the sessions and graph are reused across unit # tests and this would mean trying to reinitialize variables. Figure out # a long-term solution for this. - resources.initialize_resources([var0, var1]).run() + resources.initialize_resources([var0, var1, sgd.iteration]).run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) @@ -116,12 +116,13 @@ class GradientDescentOptimizerTest(test.TestCase): x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(var0, x) + var1 loss = pred * pred - sgd_op = gradient_descent.SGD(1.0).minimize(loss) + sgd = gradient_descent.SGD(1.0) + sgd_op = sgd.minimize(loss, [var0, var1]) # TODO(apassos) calling initialize_resources on all resources here # doesn't work because the sessions and graph are reused across unit # tests and this would mean trying to reinitialize variables. Figure out # a long-term solution for this. 
- resources.initialize_resources([var0, var1]).run() + resources.initialize_resources([var0, var1, sgd.iteration]).run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) self.assertAllCloseAccordingToType([3.0], var1.eval()) @@ -143,7 +144,7 @@ class GradientDescentOptimizerTest(test.TestCase): pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) pred += var1 loss = pred * pred - sgd_op = gradient_descent.SGD(1.0).minimize(loss) + sgd_op = gradient_descent.SGD(1.0).minimize(loss, [var0, var1]) variables.global_variables_initializer().run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) @@ -193,25 +194,24 @@ class GradientDescentOptimizerTest(test.TestCase): def testWithGlobalStep(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): - global_step = variables.Variable(0, trainable=False) var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - sgd_op = gradient_descent.SGD(3.0).apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) + sgd = gradient_descent.SGD(3.0) + sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) # Run 1 step of sgd sgd_op.run() - # Validate updated params and global_step + # Validate updated params and optimizer iterations. 
self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval()) self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) + self.assertAllCloseAccordingToType(1, sgd.iteration.eval()) def testSparseBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index fd69cd0c664..c820847e53d 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,20 @@ from __future__ import print_function import abc +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras import initializers +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as optimizer_v1 +from tensorflow.python.util import nest class OptimizerV2(optimizer_v1.Optimizer): @@ -77,29 +90,6 @@ class OptimizerV2(optimizer_v1.Optimizer): opt.apply_gradients(capped_grads_and_vars) ``` - ### Gating Gradients - - Both `minimize()` and `compute_gradients()` accept a `gate_gradients` - argument that controls the 
degree of parallelism during the application of - the gradients. - - The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`. - - `GATE_NONE`: Compute and apply gradients in parallel. This provides - the maximum parallelism in execution, at the cost of some non-reproducibility - in the results. For example the two gradients of `matmul` depend on the input - values: With `GATE_NONE` one of the gradients could be applied to one of the - inputs _before_ the other gradient is computed resulting in non-reproducible - results. - - `GATE_OP`: For each Op, make sure all gradients are computed before - they are used. This prevents race conditions for Ops that generate gradients - for multiple inputs where the gradients depend on the inputs. - - `GATE_GRAPH`: Make sure all gradients for all variables are computed - before any one of them is used. This provides the least parallelism but can - be useful if you want to process all gradients before applying any of them. - ### Slots Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer` @@ -111,11 +101,6 @@ class OptimizerV2(optimizer_v1.Optimizer): This can be useful if you want to log debug a training algorithm, report stats about the slots, etc. - ### Non-slot variables - - Some optimizer subclasses, such as `AdamOptimizer` have variables that - are not associated with the variables to train, just the step itself. - ### Hyper parameters These are arguments passed to the optimizer subclass constructor @@ -124,18 +109,8 @@ class OptimizerV2(optimizer_v1.Optimizer): callables. If they are callable, the callable will be called during `apply_gradients()` to get the value for the hyper parameter. - ### State - - Internal methods are passed a `state` argument with the correct - values to use for the slot and non-slot variables, and the hyper - parameters. """ - # Values for gate_gradients. - GATE_NONE = 0 - GATE_OP = 1 - GATE_GRAPH = 2 - def __init__(self, name): """Create a new Optimizer. 
@@ -145,6 +120,8 @@ class OptimizerV2(optimizer_v1.Optimizer): you should be able to use the _set_hyper()/state.get_hyper() facility instead. + This class is stateful and thread-compatible. + Args: name: A non-empty string. The name to use for accumulators created for the optimizer. @@ -157,6 +134,192 @@ class OptimizerV2(optimizer_v1.Optimizer): self._use_locking = True super(OptimizerV2, self).__init__(self._use_locking, name) self._hyper = {} + self._slots = {} + self._prepared = False + + def minimize(self, + loss, + var_list, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Add operations to minimize `loss` by updating `var_list`. + + This method simply combines calls to `compute_gradients()` and + `apply_gradients()`. If you want to process the gradients before applying + them, call `compute_gradients()` and `apply_gradients()` explicitly instead + of using this function. + + Args: + loss: A `Tensor` containing the value to minimize. + var_list: list or tuple of `Variable` objects to update to minimize + `loss`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with the + corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + + Returns: + An Operation that updates the variables in `var_list` and increments + the optimizer's iteration counter. + + Raises: + ValueError: If some of the variables are not `Variable` objects. + + @compatibility(eager) + When eager execution is enabled, `loss` should be a Python function that + takes no arguments and computes the value to be minimized.
Minimization (and + gradient computation) is done with respect to the elements of `var_list` if + not None, else with respect to any trainable variables created during the + execution of the `loss` function. `gate_gradients`, `aggregation_method`, + `colocate_gradients_with_ops` and `grad_loss` are ignored when eager + execution is enabled. + @end_compatibility + """ + grads_and_vars = self.compute_gradients( + loss, + var_list=var_list, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + + return self.apply_gradients(grads_and_vars, name=name) + + def compute_gradients(self, + loss, + var_list, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None, + stop_gradients=None): + """Compute gradients of `loss` for the variables in `var_list`. + + This is the first part of `minimize()`. It returns a list + of (gradient, variable) pairs where "gradient" is the gradient + for "variable". Note that "gradient" can be a `Tensor`, an + `IndexedSlices`, or `None` if there is no gradient for the + given variable. + + Args: + loss: A Tensor containing the value to minimize or a callable taking no + arguments which returns the value to minimize. When eager execution is + enabled it must be a callable. + var_list: Optional list or tuple of `tf.Variable` to update to minimize + `loss`. Defaults to the list of variables collected in the graph under + the key `GraphKeys.TRAINABLE_VARIABLES`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with the + corresponding op. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + stop_gradients: Optional. A Tensor or list of tensors not to differentiate + through. + + Returns: + A list of (gradient, variable) pairs. Variable is always present, but + gradient can be `None`. 
+ + Raises: + TypeError: If `var_list` contains anything other than `Variable` objects. + ValueError: If some arguments are invalid, or var_list is None. + RuntimeError: If called with eager execution enabled and `loss` is + not callable. + + @compatibility(eager) + When eager execution is enabled, `aggregation_method` and + `colocate_gradients_with_ops` are ignored. + @end_compatibility + """ + var_list = nest.flatten(var_list) + # TODO(josh11b): Test that we handle weight decay in a reasonable way. + if callable(loss): + with backprop.GradientTape() as tape: + tape.watch(var_list) + loss_value = loss() + grads = tape.gradient(loss_value, var_list, grad_loss) + else: + if context.executing_eagerly(): + raise RuntimeError("`loss` passed to Optimizer.compute_gradients " + "should be a function when eager execution is " + "enabled.") + self._assert_valid_dtypes([loss]) + if grad_loss is not None: + self._assert_valid_dtypes([grad_loss]) + grads = gradients.gradients( + loss, + var_list, + grad_ys=grad_loss, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + stop_gradients=stop_gradients) + + grads_and_vars = list(zip(grads, var_list)) + self._assert_valid_dtypes([ + v for g, v in grads_and_vars + if g is not None and v.dtype != dtypes.resource + ]) + + return grads_and_vars + + def apply_gradients(self, grads_and_vars, name=None): + """Apply gradients to variables. + + This is the second part of `minimize()`. It returns an `Operation` that + applies gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + name: Optional name for the returned operation. Defaults to the name + passed to the `Optimizer` constructor. + + Returns: + An `Operation` that applies the specified gradients and increments + the optimizer's iteration counter. + + Raises: + TypeError: If `grads_and_vars` is malformed.
+ ValueError: If none of the variables have gradients. + """ + grads_and_vars = _filter_grads(grads_and_vars) + var_list = [v for (_, v) in grads_and_vars] + if distribution_strategy_context.has_distribution_strategy(): + reduced_grads = merge_grads(grads_and_vars) + grads_and_vars = zip(reduced_grads, var_list) + + with ops.init_scope(): + self._create_slots(var_list) + update_ops = [] + + def update_grad_to_var(grad, var): + """Apply gradient to variable.""" + if isinstance(var, ops.Tensor): + raise NotImplementedError("Trying to update a Tensor ", var) + if isinstance(grad, ops.IndexedSlices): + if var.constraint is not None: + raise RuntimeError( + "Cannot use a constraint function on a sparse variable.") + return self._resource_apply_sparse_duplicate_indices( + grad.values, var, grad.indices) + update_op = self._resource_apply_dense(grad, var) + if var.constraint is not None: + with ops.control_dependencies([update_op]): + return var.assign(var.constraint(var)) + else: + return update_op + + with ops.name_scope(name, self._name) as name: + self._prepare() + for grad, var in grads_and_vars: + scope_name = "" if in_eager_execution() else "_" + var.op.name + with ops.name_scope("update" + scope_name), ops.colocate_with(var): + update_ops.append(update_grad_to_var(grad, var)) + with ops.colocate_with(self._iterations): + update_ops.append(self._iterations.assign_add(1)) + return control_flow_ops.group(*update_ops) def _set_hyper(self, name, value): self._hyper[name] = value @@ -166,8 +329,33 @@ class OptimizerV2(optimizer_v1.Optimizer): value = self._hyper[name] return self._call_if_callable(value) + def add_slot(self, var, slot_name): + slot_key = _get_slot_key_from_var(var, slot_name) + if slot_key not in self._slots: + self._slots[slot_key] = self.add_weight( + name=slot_key, shape=var.shape, dtype=var.dtype) + + def get_slot(self, var, slot_name): + slot_key = _get_slot_key_from_var(var, slot_name) + return self._slots[slot_key] + def _prepare(self): - pass + 
if self._prepared: + return + # This is where all hyper variables will be created. + with ops.device("cpu:0"): + self._iterations = self.add_weight( + self._name + "/iter", + shape=[], + trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + self._prepared = True + + @property + def iteration(self): + if not self._prepared: + self._prepare() + return self._iterations @abc.abstractmethod def get_config(self): @@ -205,3 +393,116 @@ class OptimizerV2(optimizer_v1.Optimizer): def _serialize_hyperparameter(self, hyperparameter_name): """Serialize a hyperparameter that can be a float, callable, or Tensor.""" return self._hyper[hyperparameter_name] + + def add_weight(self, + name, + shape, + dtype=None, + initializer="zeros", + trainable=None, + synchronization=variables.VariableSynchronization.AUTO, + aggregation=variables.VariableAggregation.NONE): + + if dtype is None: + dtype = dtypes.float32 + initializer = initializers.get(initializer) + + if synchronization == variables.VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ.") + else: + # Set trainable to be false when variable is to be synced on read. 
+        trainable = False
+    elif trainable is None:
+      trainable = True
+
+    variable = self._add_variable_with_custom_getter(
+        name=name,
+        shape=shape,
+        getter=base_layer.make_variable,
+        overwrite=True,
+        initializer=initializers.get(initializer),
+        dtype=dtype,
+        trainable=trainable,
+        use_resource=True,
+        synchronization=synchronization,
+        aggregation=aggregation)
+
+    return variable
+
+
+def _filter_grads(grads_and_vars):
+  """Drop (grad, var) pairs whose grad is None; raise if none remain."""
+  grads_and_vars = tuple(grads_and_vars)
+  if not grads_and_vars:
+    raise ValueError("No variables provided.")
+  filtered = []
+  vars_with_empty_grads = []
+  for grad, var in grads_and_vars:
+    if grad is None:
+      vars_with_empty_grads.append(var)
+    else:
+      filtered.append((grad, var))
+  filtered = tuple(filtered)
+  if not filtered:
+    raise ValueError("No gradients provided for any variable: %s." %
+                     ([v.name for _, v in grads_and_vars],))
+  if vars_with_empty_grads:
+    logging.warning(
+        ("Gradients does not exist for variables %s when minimizing the loss."),
+        ([v.name for v in vars_with_empty_grads]))
+  return filtered
+
+
+def merge_grads(grads_and_vars):
+  """Merge gradients from different replicas."""
+
+  def merge_grad_fn(strategy, grads_and_vars):
+    reduced_grads = strategy.batch_reduce(
+        variable_scope.VariableAggregation.MEAN, grads_and_vars)
+    return reduced_grads
+
+  return distribution_strategy_context.get_tower_context().merge_call(
+      merge_grad_fn, grads_and_vars)
+
+
+def in_eager_execution():
+  with ops.init_scope():
+    return context.executing_eagerly()
+
+
+def _get_slot_key_from_var(var, slot_name):
+  """Get the slot key for the variable.
+
+  Scope the slot name in the namespace of the primary variable.
+  Set "primary.op.name + '/' + slot_name" as default name.
+
+  In graph mode the name is derived from the op.
+  In eager mode the name is derived from the var.
+ If distribution strategy exists, then the name is derived from the primary + variable instead of replica variable, i.e., /dense/kernel instead of + /dense/kernel/replica_1. If the slot name is 'm', then the slot variables + being created are /dense/kernel/m and /dense/kernel/m/replica_1, instead of + /dense/kernel/replica_1/m/replica_1. + + Args: + var: the variable. + slot_name: the name of the slot. + + Returns: + the name of the variable. + """ + + # pylint: disable=protected-access + if distribution_strategy_context.has_distribution_strategy() and hasattr( + var, "_primary_var"): + var = var._primary_var + if context.executing_eagerly(): + name = var._shared_name + else: + name = var.op.name + return name + "/" + slot_name diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index 15f33eb7baa..fe12ab204f1 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,194 +37,179 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBasic(self): - for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - def loss(): - return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop - # Note that for eager execution, minimize expects a function instead of a - # Tensor. - global_step = resource_variable_ops.ResourceVariable( - array_ops.zeros([], dtypes.int64), name='global_step_%d' % i) - sgd_op = gradient_descent.SGD(3.0) + for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.cached_session(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + if not context.executing_eagerly(): + loss = loss() + sgd = gradient_descent.SGD(3.0) - self.evaluate(variables.global_variables_initializer()) - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - # Run 1 step of sgd through optimizer - opt_op = sgd_op.minimize(loss, global_step, [var0, var1]) - self.evaluate(opt_op) - # Validate updated params - self.assertAllClose([-14., -13.], self.evaluate(var0)) - self.assertAllClose([-6., -5.], self.evaluate(var1)) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + 
self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Run 1 step of sgd through optimizer + opt_op = sgd.minimize(loss, var_list=[var0, var1]) + self.evaluate(sgd.iteration.initializer) + self.evaluate(opt_op) + # Validate updated params + self.assertAllClose([-14., -13.], self.evaluate(var0)) + self.assertAllClose([-6., -5.], self.evaluate(var1)) + @test_util.run_in_graph_and_eager_modes def testAggregationMethod(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) - cost = 5 * var0 + 3 * var1 - global_step = variables.Variable( - array_ops.zeros([], dtypes.int64), name='global_step') - sgd_op = gradient_descent.SGD(3.0) - opt_op = sgd_op.minimize( - cost, - global_step, [var0, var1], - aggregation_method=gradients_impl.AggregationMethod. - EXPERIMENTAL_ACCUMULATE_N) + loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + if not context.executing_eagerly(): + loss = loss() + sgd = gradient_descent.SGD(3.0) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer - opt_op.run() + opt_op = sgd.minimize( + loss, + var_list=[var0, var1], + aggregation_method=gradients_impl.AggregationMethod + .EXPERIMENTAL_ACCUMULATE_N) + self.evaluate(sgd.iteration.initializer) + self.evaluate(opt_op) # Validate updated params - self.assertAllClose([-14., -13.], var0.eval()) - self.assertAllClose([-6., -5.], var1.eval()) + self.assertAllClose([-14., -13.], self.evaluate(var0)) + self.assertAllClose([-6., -5.], self.evaluate(var1)) + 
@test_util.run_in_graph_and_eager_modes def testPrecomputedGradient(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) - cost = 5 * var0 + 3 * var1 + loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + if not context.executing_eagerly(): + loss = loss() grad_loss = constant_op.constant([42, -42], dtype=dtype) - global_step = variables.Variable( - array_ops.zeros([], dtypes.int64), name='global_step') - sgd_op = gradient_descent.SGD(3.0) - opt_op = sgd_op.minimize( - cost, global_step, [var0, var1], grad_loss=grad_loss) + sgd = gradient_descent.SGD(3.0) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer - opt_op.run() + opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss) + self.evaluate(sgd.iteration.initializer) + self.evaluate(opt_op) # Validate updated params self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)], - var0.eval()) + self.evaluate(var0)) self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], - var1.eval()) - - @test_util.run_in_graph_and_eager_modes - def testNoVariables(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - # pylint: disable=cell-var-from-loop - def loss(): - var0 = resource_variable_ops.ResourceVariable( - [1.0, 2.0], dtype=dtype, trainable=False, name='a') - var1 = resource_variable_ops.ResourceVariable( - [3.0, 4.0], dtype=dtype, trainable=False, name='b') - return 5 * var0 + var1 - # pylint: enable=cell-var-from-loop - sgd_op = gradient_descent.SGD(3.0) - with 
self.assertRaisesRegexp(ValueError, 'No.*variables'): - sgd_op.minimize(loss) + self.evaluate(var1)) @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - # pylint: disable=cell-var-from-loop - def loss(): - return 5 * var0 - # pylint: enable=cell-var-from-loop - sgd_op = gradient_descent.SGD(3.0) - with self.assertRaisesRegexp(ValueError, 'No gradients'): - # var1 has no gradient - sgd_op.minimize(loss, var_list=[var1]) + with self.cached_session(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + loss = lambda: 5 * var0 # pylint: disable=cell-var-from-loop + if not context.executing_eagerly(): + loss = loss() + sgd_op = gradient_descent.SGD(3.0) + with self.assertRaisesRegexp(ValueError, 'No gradients'): + # var1 has no gradient + sgd_op.minimize(loss, var_list=[var1]) @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - def loss(): - return constant_op.constant(5.0) + with self.cached_session(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + loss = lambda: constant_op.constant(5.0) + if not context.executing_eagerly(): + loss = loss() - sgd_op = gradient_descent.SGD(3.0) - with self.assertRaisesRegexp(ValueError, - 'No gradients provided for any variable'): - sgd_op.minimize(loss, var_list=[var0, var1]) + sgd_op = gradient_descent.SGD(3.0) + with 
self.assertRaisesRegexp(ValueError, + 'No gradients provided for any variable'): + sgd_op.minimize(loss, var_list=[var0, var1]) @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - sgd_op = gradient_descent.SGD(3.0) - with self.assertRaisesRegexp(ValueError, - 'No gradients provided for any variable'): - sgd_op.apply_gradients([(None, var0), (None, var1)]) + with self.cached_session(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + sgd_op = gradient_descent.SGD(3.0) + with self.assertRaisesRegexp(ValueError, + 'No gradients provided for any variable'): + sgd_op.apply_gradients([(None, var0), (None, var1)]) @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - def loss(): - return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + with self.cached_session(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + loss = lambda: 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + if not context.executing_eagerly(): + loss = loss() - sgd_op = gradient_descent.SGD(3.0) - grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1]) - # Convert gradients to tf.Variables - converted_grads = [ - resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype), - name='c_%d_%d' % (i, j)) - for j, gv in enumerate(grads_and_vars) - ] - convert_ops = [ - 
state_ops.assign(converted_grads[j], gv[0]) - for j, gv in enumerate(grads_and_vars) - ] + sgd = gradient_descent.SGD(3.0) + grads_and_vars = sgd.compute_gradients(loss, [var0, var1]) + # Convert gradients to tf.Variables + converted_grads = [ + resource_variable_ops.ResourceVariable( + array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j)) + for j, gv in enumerate(grads_and_vars) + ] + convert_ops = [ + state_ops.assign(converted_grads[j], gv[0]) + for j, gv in enumerate(grads_and_vars) + ] - self.evaluate(variables.global_variables_initializer()) - # Run convert_ops to achieve the gradietns converting - self.evaluate(convert_ops) - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Run convert_ops to achieve the gradients converting + self.evaluate(variables.global_variables_initializer()) + self.evaluate(convert_ops) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - # Run 1 step of sgd through optimizer - converted_grads_and_vars = list(zip(converted_grads, [var0, var1])) - opt_op = sgd_op.apply_gradients(converted_grads_and_vars) - self.evaluate(opt_op) + # Run 1 step of sgd through optimizer + converted_grads_and_vars = list(zip(converted_grads, [var0, var1])) + opt_op = sgd.apply_gradients(converted_grads_and_vars) + self.evaluate(sgd.iteration.initializer) + self.evaluate(opt_op) - # Validate updated params - self.assertAllClose([-14., -13.], self.evaluate(var0)) - self.assertAllClose([-6., -5.], self.evaluate(var1)) + # Validate updated params + self.assertAllClose([-14., -13.], self.evaluate(var0)) + self.assertAllClose([-6., -5.], self.evaluate(var1)) @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): - x = ops.convert_to_tensor(1.0) - def f(): - return x * x - - sgd_op = gradient_descent.SGD(3.0) - grads_and_vars = 
sgd_op.compute_gradients(f, [x]) - self.assertEqual(1, len(grads_and_vars)) - grad, x_as_var = grads_and_vars[0] - self.assertIs(x, x_as_var) - self.assertEqual(2.0, self.evaluate(grad)) - - with self.assertRaises(NotImplementedError): - sgd_op.apply_gradients(grads_and_vars) - - def testTrainOp(self): with self.cached_session(): - var0 = variables.Variable([1.0, 2.0]) - var1 = variables.Variable([3.0, 4.0]) - cost = 5 * var0 + 3 * var1 - global_step = variables.Variable( - array_ops.zeros([], dtypes.int64), name='global_step') - sgd_op = gradient_descent.SGD(3.0) - opt_op = sgd_op.minimize(cost, global_step, [var0, var1]) - self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP)) + x = ops.convert_to_tensor(1.0) + def f(): + return x * x + + sgd = gradient_descent.SGD(3.0) + grads_and_vars = sgd.compute_gradients(f, [x]) + self.assertEqual(1, len(grads_and_vars)) + grad, x_as_var = grads_and_vars[0] + self.assertIs(x, x_as_var) + self.assertEqual(2.0, self.evaluate(grad)) + + with self.assertRaises(NotImplementedError): + sgd.apply_gradients(grads_and_vars) + + @test_util.run_in_graph_and_eager_modes def testConstraint(self): constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.) constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.) 
@@ -232,21 +218,29 @@ class OptimizerTest(test.TestCase): constraint=constraint_01) var1 = variables.Variable([3.0, 4.0], constraint=constraint_0) - cost = 5 * var0 + 3 * var1 - global_step = variables.Variable( - array_ops.zeros([], dtypes.int64), name='global_step') - sgd_op = gradient_descent.SGD(3.0) - opt_op = sgd_op.minimize(cost, global_step, [var0, var1]) + loss = lambda: 5 * var0 + 3 * var1 + if not context.executing_eagerly(): # pylint: disable=cell-var-from-loop + loss = loss() + sgd = gradient_descent.SGD(3.0) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer - opt_op.run() + opt_op = sgd.minimize(loss, var_list=[var0, var1]) + self.evaluate(sgd.iteration.initializer) + self.evaluate(opt_op) # Validate updated params - self.assertAllClose([-0.1, -0.1], var0.eval()) - self.assertAllClose([0., 0.], var1.eval()) + self.assertAllClose([-0.1, -0.1], self.evaluate(var0)) + self.assertAllClose([0., 0.], self.evaluate(var1)) + + @test_util.run_in_graph_and_eager_modes + def testIterationWithoutMinimize(self): + with self.cached_session(): + sgd = gradient_descent.SGD(3.0) + self.evaluate(sgd.iteration.initializer) + self.assertEqual(0, self.evaluate(sgd.iteration)) if __name__ == '__main__':