Fix the device placement test - check the variable's device, instead of the
read tensor's device. PiperOrigin-RevId: 257447029
This commit is contained in:
parent
3b1c5e378b
commit
1dffb00347
@ -580,7 +580,7 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
# TODO(b/137148281): reenable
|
||||
@test_util.run_gpu_only
|
||||
def DISABLED_testDeviceAnnotationRespected(self):
|
||||
def testDeviceAnnotationRespected(self):
|
||||
a = []
|
||||
|
||||
@def_function.function()
|
||||
@ -590,13 +590,13 @@ class DefFunctionTest(test.TestCase):
|
||||
(2, 2), maxval=1000000, dtype=dtypes.int64)
|
||||
|
||||
if not a:
|
||||
with ops.device("CPU:0"):
|
||||
with ops.device('CPU:0'):
|
||||
a.append(resource_variable_ops.ResourceVariable(initial_value))
|
||||
|
||||
return a[0].read_value()
|
||||
|
||||
created_variable_read = create_variable()
|
||||
self.assertRegexpMatches(created_variable_read.device, "CPU")
|
||||
self.assertRegexpMatches(a[0].device, 'CPU')
|
||||
|
||||
def testDecorate(self):
|
||||
func = def_function.function(lambda: 1)
|
||||
|
Loading…
Reference in New Issue
Block a user