Fix //third_party/tensorflow/python:session_test. This failure is caused by constant folding.
PiperOrigin-RevId: 351837260 Change-Id: Ia874803a85cc378a8fe0b6e89d9bbbcf928aa7b5
This commit is contained in:
parent
4133fef691
commit
7d43911faa
@ -1965,17 +1965,17 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(len(add_executions), 2)
|
self.assertEqual(len(add_executions), 2)
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def fn():
|
def fn(a, b):
|
||||||
a = constant_op.constant(1)
|
|
||||||
b = constant_op.constant(2)
|
|
||||||
c = a + b
|
c = a + b
|
||||||
d = a + b
|
# These two AddV2 cannot use the same argument in tf.function since an
|
||||||
|
# optimization pass will remove duplicate ops and only run it once.
|
||||||
|
d = a + c
|
||||||
return c, d
|
return c, d
|
||||||
|
|
||||||
with CaptureStderr() as log:
|
with CaptureStderr() as log:
|
||||||
c, d = self.evaluate(fn())
|
c, d = self.evaluate(fn(constant_op.constant(1), constant_op.constant(2)))
|
||||||
self.assertEqual(c, 3)
|
self.assertEqual(c, 3)
|
||||||
self.assertEqual(d, 3)
|
self.assertEqual(d, 4)
|
||||||
# Ensure that we did log device placement.
|
# Ensure that we did log device placement.
|
||||||
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
||||||
self.assertEqual(len(add_executions), 2)
|
self.assertEqual(len(add_executions), 2)
|
||||||
|
Loading…
Reference in New Issue
Block a user