Fix the device placement test - check the variable's device, instead of the

read tensor's device.

PiperOrigin-RevId: 257447029
This commit is contained in:
Akshay Modi 2019-07-10 11:32:24 -07:00 committed by TensorFlower Gardener
parent 3b1c5e378b
commit 1dffb00347

View File

@ -580,7 +580,7 @@ class DefFunctionTest(test.TestCase):
# TODO(b/137148281): reenable
@test_util.run_gpu_only
def DISABLED_testDeviceAnnotationRespected(self):
def testDeviceAnnotationRespected(self):
a = []
@def_function.function()
@ -590,13 +590,13 @@ class DefFunctionTest(test.TestCase):
(2, 2), maxval=1000000, dtype=dtypes.int64)
if not a:
with ops.device("CPU:0"):
with ops.device('CPU:0'):
a.append(resource_variable_ops.ResourceVariable(initial_value))
return a[0].read_value()
created_variable_read = create_variable()
self.assertRegexpMatches(created_variable_read.device, "CPU")
self.assertRegexpMatches(a[0].device, 'CPU')
def testDecorate(self):
func = def_function.function(lambda: 1)