Colocate slot variables with their base variable in keras OptimizerV2

Not doing so results in placement errors, because ops such as ResourceApplyRMSProp
take both the base variable and its slot variables as inputs. Such an op cannot
be placed when those variables live on different devices.
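
The idea, as a minimal graph-mode sketch (illustrative only; the variable names
are made up, and the public tf.colocate_with stands in for the internal scope
the fix actually uses):

    # TF 1.x-era API, graph mode.
    import tensorflow as tf

    with tf.device("/device:CPU:0"):
      base = tf.Variable([1.0, 2.0], name="base")

    with tf.device("/device:GPU:0"):
      # Without colocation, the slot would be created on the GPU here, and an
      # op like ResourceApplyRMSProp that reads and updates both variables
      # would have no valid placement. The colocation scope overrides the
      # enclosing GPU device scope and pins the slot to base's device (CPU).
      with tf.colocate_with(base):
        rms = tf.Variable(tf.zeros_like(base), trainable=False, name="base/rms")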

PiperOrigin-RevId: 226924297
Authored by Igor Ganichev on 2018-12-26 09:05:31 -08:00, committed by TensorFlower Gardener
parent 2fc6decf7b
commit 2696015d85
4 changed files with 108 additions and 5 deletions

tensorflow/contrib/optimizer_v2/rmsprop_test.py

@@ -25,10 +25,12 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.optimizer_v2 import rmsprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -448,5 +450,56 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
        ]), var1.eval())


class SlotColocationTest(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([True, False])
  @test_util.run_in_graph_and_eager_modes
  def testRunMinimizeOnGPUForCPUVariables(self, use_resource):
    if not context.context().num_gpus():
      self.skipTest("No GPUs found")

    with ops.device("/device:CPU:0"):
      if use_resource:
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                      dtype=dtypes.float32)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
                                                      dtype=dtypes.float32)
        global_step = resource_variable_ops.ResourceVariable(
            array_ops.zeros([], dtypes.int64), name="global_step")
      else:
        var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
        var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name="global_step")

    def loss():
      return 5 * var0 + 3 * var1

    opt = rmsprop.RMSPropOptimizer(
        learning_rate=1.0, decay=0.9, momentum=0.5, epsilon=1.0)

    # Fetch params to validate initial values.
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
    self.assertAllClose([3.0, 4.0], self.evaluate(var1))

    # Run one step through the optimizer on GPU. Slot variables are created
    # the first time the optimizer is used on a variable; this tests that
    # the slots are colocated with the base variable.
    with ops.device("/device:GPU:0"):
      # Note that for eager execution, minimize expects a function instead
      # of a Tensor.
      opt_op = opt.minimize(loss, global_step, [var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

    # Validate updated params. All variables should have decreased.
    self.assertTrue(all(v < 0.0 for v in self.evaluate(var0)),
                    msg="updated variables: %s" % self.evaluate(var0))
    self.assertTrue(all(v < 2.0 for v in self.evaluate(var1)),
                    msg="updated variables: %s" % self.evaluate(var1))


if __name__ == "__main__":
  test.main()
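
A direct placement assertion would make the colocation check explicit. A
hypothetical addition to the test above (not part of this commit), assuming
the optimizer exposes get_slot and names its mean-square accumulator "rms"
as tf.train.RMSPropOptimizer does:

    # Hypothetical check: the slot should live on the base variable's device.
    rms_slot = opt.get_slot(var0, "rms")
    self.assertEqual(var0.device, rms_slot.device)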

tensorflow/python/keras/optimizer_v2/BUILD

@@ -201,6 +201,7 @@ cuda_py_test(
    srcs = ["rmsprop_test.py"],
    additional_deps = [
        ":optimizer_v2",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",

tensorflow/python/keras/optimizer_v2/optimizer_v2.py

@@ -443,11 +443,12 @@ class OptimizerV2(checkpointable.CheckpointableBase):
             initializer, shape=var.shape, dtype=var.dtype)
       else:
         initial_value = initializer
-      weight = tf_variables.Variable(
-          name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
-          dtype=var.dtype,
-          trainable=False,
-          initial_value=initial_value)
+      with ops._colocate_with_for_gradient(var, None):  # pylint: disable=protected-access
+        weight = tf_variables.Variable(
+            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
+            dtype=var.dtype,
+            trainable=False,
+            initial_value=initial_value)
       backend.track_variable(weight)
       slot_dict[slot_name] = weight
       self._restore_slot_variable(
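
What the new scope accomplishes, as a rough public-API sketch (an assumed
equivalent, not the actual implementation; ops._colocate_with_for_gradient
is the internal, gradient-aware variant of this colocation scope):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables as tf_variables

    def make_slot(var, initial_value, name):
      # Colocate the slot with its base variable so it lands on var's
      # device, overriding whatever device scope minimize() runs under.
      with ops.colocate_with(var):
        return tf_variables.Variable(
            name=name, dtype=var.dtype, trainable=False,
            initial_value=initial_value)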

tensorflow/python/keras/optimizer_v2/rmsprop_test.py

@@ -22,6 +22,7 @@ import copy
import itertools
import math
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
@@ -421,5 +422,52 @@ class RMSpropOptimizerTest(test.TestCase):
    self.assertEqual(opt_3.lr, 0.1)


class SlotColocationTest(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([True, False])
  @test_util.run_in_graph_and_eager_modes
  def testRunMinimizeOnGPUForCPUVariables(self, use_resource):
    if not context.context().num_gpus():
      self.skipTest("No GPUs found")

    with ops.device("/device:CPU:0"):
      if use_resource:
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                      dtype=dtypes.float32)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
                                                      dtype=dtypes.float32)
      else:
        var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
        var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)

    def loss():
      return 5 * var0 + 3 * var1

    opt = rmsprop.RMSprop(
        learning_rate=1.0, decay=0.9, momentum=0.5, epsilon=1.0)

    # Fetch params to validate initial values.
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
    self.assertAllClose([3.0, 4.0], self.evaluate(var1))

    # Run one step through the optimizer on GPU. Slot variables are created
    # the first time the optimizer is used on a variable; this tests that
    # the slots are colocated with the base variable.
    with ops.device("/device:GPU:0"):
      # Note that for eager execution, minimize expects a function instead
      # of a Tensor.
      opt_op = opt.minimize(loss, [var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

    # Validate updated params. All variables should have decreased.
    self.assertTrue(all(v < 0.0 for v in self.evaluate(var0)),
                    msg="updated variables: %s" % self.evaluate(var0))
    self.assertTrue(all(v < 2.0 for v in self.evaluate(var1)),
                    msg="updated variables: %s" % self.evaluate(var1))


if __name__ == "__main__":
  test.main()