Colocate slot variables with their base variable in keras OptimizerV2
Not doing so results in placement errors: ops such as ResourceApplyRMSProp take the base variable as well as its slot variables, and such an op cannot be placed when those variables live on different devices.

PiperOrigin-RevId: 226924297
Parent: 2fc6decf7b
Commit: 2696015d85
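For context, the core idea of the fix can be sketched as follows. This is a hypothetical standalone snippet, not code from this commit: it creates a slot variable inside a colocation scope so the slot inherits the base variable's device, which is what lets a fused op like ResourceApplyRMSProp (which reads the base variable and its slots together) be placed. The helper name and the variable naming scheme are illustrative assumptions.

# Hypothetical sketch of the colocation idea (not code from this commit).
# ops.colocate_with(var) pins everything created in its scope to the same
# device as `var`, so the slot ends up wherever the base variable lives.
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables as tf_variables


def create_colocated_slot(var, slot_name, initial_value):
  """Creates a non-trainable slot variable on the same device as `var`."""
  with ops.colocate_with(var):
    return tf_variables.Variable(
        initial_value,
        name="%s/%s" % (var.name.split(":")[0], slot_name),
        dtype=var.dtype,
        trainable=False)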
@@ -25,10 +25,12 @@ from absl.testing import parameterized
import numpy as np

from tensorflow.contrib.optimizer_v2 import rmsprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -448,5 +450,56 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
        ]), var1.eval())


class SlotColocationTest(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([True, False])
  @test_util.run_in_graph_and_eager_modes
  def testRunMinimizeOnGPUForCPUVariables(self, use_resource):
    if not context.context().num_gpus():
      self.skipTest("No GPUs found")

    with ops.device("/device:CPU:0"):
      if use_resource:
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                      dtype=dtypes.float32)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
                                                      dtype=dtypes.float32)
        global_step = resource_variable_ops.ResourceVariable(
            array_ops.zeros([], dtypes.int64), name="global_step")
      else:
        var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
        var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name="global_step")

    def loss():
      return 5 * var0 + 3 * var1

    opt = rmsprop.RMSPropOptimizer(
        learning_rate=1.0, decay=0.9, momentum=0.5, epsilon=1.0)

    # Fetch params to validate initial values.
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
    self.assertAllClose([3.0, 4.0], self.evaluate(var1))

    # Run one step through the optimizer on the GPU. Slot variables are
    # created the first time the optimizer is used on a variable, so this
    # checks that slot variables are colocated with their base variable.
    with ops.device("/device:GPU:0"):
      # Note that for eager execution, minimize expects a function instead
      # of a Tensor.
      opt_op = opt.minimize(loss, global_step, [var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

    # Validate updated params; all variables should have decreased.
    self.assertTrue(all(v < 0.0 for v in self.evaluate(var0)),
                    msg="updated variables: %s" % self.evaluate(var0))
    self.assertTrue(all(v < 2.0 for v in self.evaluate(var1)),
                    msg="updated variables: %s" % self.evaluate(var1))


if __name__ == "__main__":
  test.main()
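A natural extension of the test above (hypothetical, not in the commit) would assert the colocation directly by fetching a slot and comparing devices. The slot name "rms" is an assumption about this optimizer's slot naming; adjust it if the optimizer uses a different name.

    # Hypothetical follow-up assertions inside testRunMinimizeOnGPUForCPUVariables.
    # Assumes the RMSProp optimizer names its mean-square slot "rms".
    rms_slot0 = opt.get_slot(var0, "rms")
    self.assertIsNotNone(rms_slot0)
    self.assertEqual(rms_slot0.device, var0.device)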
@@ -201,6 +201,7 @@ cuda_py_test(
    srcs = ["rmsprop_test.py"],
    additional_deps = [
        ":optimizer_v2",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
@@ -443,6 +443,7 @@ class OptimizerV2(checkpointable.CheckpointableBase):
          initializer, shape=var.shape, dtype=var.dtype)
    else:
      initial_value = initializer
    with ops._colocate_with_for_gradient(var, None):  # pylint: disable=protected-access
      weight = tf_variables.Variable(
          name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
          dtype=var.dtype,
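The `ops._colocate_with_for_gradient(var, None)` scope above is the substance of the fix: the slot `weight` is now created under a colocation constraint derived from `var`. Judging only by the name, the helper presumably acts like `ops.colocate_with(var)` while cooperating with gradient-UID colocation scopes; that reading is an assumption, as the diff itself does not say. A hypothetical sanity check that could follow the variable creation (not in the commit; it assumes `context` is imported from tensorflow.python.eager, and is most meaningful under eager execution, since in graph mode device strings may stay empty until placement):

    # Hypothetical check, not in the commit: under eager execution the slot
    # should resolve to the same device as its base variable.
    if context.executing_eagerly():
      assert weight.device == var.device, (
          "slot %s placed on %r, base variable on %r" %
          (slot_name, weight.device, var.device))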
@@ -22,6 +22,7 @@ import copy
import itertools
import math

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import context
@@ -421,5 +422,52 @@ class RMSpropOptimizerTest(test.TestCase):
    self.assertEqual(opt_3.lr, 0.1)


class SlotColocationTest(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([True, False])
  @test_util.run_in_graph_and_eager_modes
  def testRunMinimizeOnGPUForCPUVariables(self, use_resource):
    if not context.context().num_gpus():
      self.skipTest("No GPUs found")

    with ops.device("/device:CPU:0"):
      if use_resource:
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                      dtype=dtypes.float32)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
                                                      dtype=dtypes.float32)
      else:
        var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
        var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)

    def loss():
      return 5 * var0 + 3 * var1

    opt = rmsprop.RMSprop(
        learning_rate=1.0, decay=0.9, momentum=0.5, epsilon=1.0)

    # Fetch params to validate initial values.
    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
    self.assertAllClose([3.0, 4.0], self.evaluate(var1))

    # Run one step through the optimizer on the GPU. Slot variables are
    # created the first time the optimizer is used on a variable, so this
    # checks that slot variables are colocated with their base variable.
    with ops.device("/device:GPU:0"):
      # Note that for eager execution, minimize expects a function instead
      # of a Tensor.
      opt_op = opt.minimize(loss, [var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

    # Validate updated params; all variables should have decreased.
    self.assertTrue(all(v < 0.0 for v in self.evaluate(var0)),
                    msg="updated variables: %s" % self.evaluate(var0))
    self.assertTrue(all(v < 2.0 for v in self.evaluate(var1)),
                    msg="updated variables: %s" % self.evaluate(var1))


if __name__ == "__main__":
  test.main()